ant_quic/masque/relay_client.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8//! MASQUE Relay Client
9//!
10//! Implements a client for connecting to MASQUE CONNECT-UDP Bind relays.
11//! Used when direct NAT traversal fails and relay fallback is needed.
12//!
13//! # Overview
14//!
//! The relay client implements the client side of a relay session. It:
//! - Negotiates a CONNECT-UDP Bind session
//! - Learns its public address from the relay
//! - Manages context registrations for efficient datagram forwarding
//! - Builds outgoing datagrams and decodes incoming ones; the caller is
//!   responsible for transporting requests, capsules and datagrams
20//!
//! # Example
//!
//! ```rust,ignore
//! use ant_quic::masque::relay_client::{MasqueRelayClient, RelayClientConfig};
//! use std::net::SocketAddr;
//!
//! // Create a client for a relay
//! let relay_addr: SocketAddr = "203.0.113.50:9000".parse().unwrap();
//! let client = MasqueRelayClient::new(relay_addr, RelayClientConfig::default());
//!
//! // Send `client.create_connect_request()` to the relay, then apply its response
//! client.handle_connect_response(response).await?;
//!
//! // Get our public address as reported by the relay
//! let public_addr = client.public_address().await;
//!
//! // Build a datagram for a target; if a capsule is returned, send it first
//! let (datagram, capsule) = client.create_datagram(target_addr, data).await?;
//! ```
38
39use bytes::Bytes;
40use std::collections::HashMap;
41use std::net::SocketAddr;
42use std::sync::Arc;
43use std::sync::atomic::{AtomicU64, Ordering};
44use std::time::{Duration, Instant};
45use tokio::sync::RwLock;
46
47use crate::VarInt;
48use crate::masque::{
49    Capsule, CompressedDatagram, CompressionAck, CompressionAssign, CompressionClose, ConnectError,
50    ConnectUdpRequest, ConnectUdpResponse, ContextManager, Datagram, UncompressedDatagram,
51};
52use crate::relay::error::{RelayError, RelayResult, SessionErrorKind};
53
/// Tunables for a [`MasqueRelayClient`].
///
/// NOTE(review): within this module only `max_pending_contexts` is read, and only
/// as the `limit` reported when context allocation fails; `connect_timeout`,
/// `keepalive_interval` and `prefer_compressed` are not consulted here.
#[derive(Debug, Clone)]
pub struct RelayClientConfig {
    /// Timeout for establishing the relay connection.
    pub connect_timeout: Duration,
    /// Interval at which the relay session should be kept alive.
    pub keepalive_interval: Duration,
    /// Upper bound on context registrations awaiting acknowledgement.
    pub max_pending_contexts: usize,
    /// Whether compressed contexts are preferred over uncompressed ones.
    pub prefer_compressed: bool,
}

impl Default for RelayClientConfig {
    /// 10 s connect timeout, 30 s keepalive, 50 pending contexts, compression preferred.
    fn default() -> Self {
        let connect_timeout = Duration::from_secs(10);
        let keepalive_interval = Duration::from_secs(30);
        Self {
            connect_timeout,
            keepalive_interval,
            max_pending_contexts: 50,
            prefer_compressed: true,
        }
    }
}
77
/// Lifecycle of a relay session as tracked by [`MasqueRelayClient`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RelayConnectionState {
    /// No session; the state of a freshly created client.
    Disconnected,
    /// A session is being set up.
    /// NOTE(review): nothing in this module currently assigns this state.
    Connecting,
    /// The relay accepted the CONNECT-UDP Bind request.
    Connected,
    /// The relay answered the CONNECT-UDP Bind request with a non-success status.
    Failed,
    /// The client was shut down via `close()`.
    Closed,
}
92
/// Traffic and activity counters for a relay client.
///
/// Every counter is an independent tally updated with `Ordering::Relaxed`;
/// none of them is used for synchronization.
#[derive(Debug, Default)]
pub struct RelayClientStats {
    /// Payload bytes sent through the relay.
    pub bytes_sent: AtomicU64,
    /// Payload bytes received through the relay.
    pub bytes_received: AtomicU64,
    /// Number of datagrams sent.
    pub datagrams_sent: AtomicU64,
    /// Number of datagrams received.
    pub datagrams_received: AtomicU64,
    /// Number of contexts acknowledged.
    pub contexts_registered: AtomicU64,
    /// Number of connection attempts.
    pub connection_attempts: AtomicU64,
}

impl RelayClientStats {
    /// All counters start at zero (equivalent to `Default::default()`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Count one outgoing datagram carrying `bytes` bytes.
    pub fn record_sent(&self, bytes: u64) {
        Self::tally(&self.bytes_sent, &self.datagrams_sent, bytes);
    }

    /// Count one incoming datagram carrying `bytes` bytes.
    pub fn record_received(&self, bytes: u64) {
        Self::tally(&self.bytes_received, &self.datagrams_received, bytes);
    }

    /// Count one acknowledged context.
    pub fn record_context(&self) {
        self.contexts_registered.fetch_add(1, Ordering::Relaxed);
    }

    /// Running total of bytes sent.
    pub fn total_sent(&self) -> u64 {
        self.bytes_sent.load(Ordering::Relaxed)
    }

    /// Running total of bytes received.
    pub fn total_received(&self) -> u64 {
        self.bytes_received.load(Ordering::Relaxed)
    }

    /// Add `bytes` to a byte counter, then bump its paired datagram counter by one.
    fn tally(byte_counter: &AtomicU64, datagram_counter: &AtomicU64, bytes: u64) {
        byte_counter.fetch_add(bytes, Ordering::Relaxed);
        datagram_counter.fetch_add(1, Ordering::Relaxed);
    }
}
143
144/// Pending datagram awaiting context acknowledgement
145#[derive(Debug)]
146struct PendingDatagram {
147    /// Target address for the datagram
148    target: SocketAddr,
149    /// The datagram payload (stored for retry after ACK)
150    #[allow(dead_code)]
151    payload: Bytes,
152    /// When the datagram was queued (for timeout handling)
153    #[allow(dead_code)]
154    created_at: Instant,
155}
156
157/// MASQUE Relay Client
158///
159/// Manages a connection to a MASQUE relay server and provides
160/// APIs for sending and receiving datagrams through the relay.
161#[derive(Debug)]
162pub struct MasqueRelayClient {
163    /// Configuration
164    config: RelayClientConfig,
165    /// Relay server address
166    relay_address: SocketAddr,
167    /// Our public address as seen by the relay
168    public_address: RwLock<Option<SocketAddr>>,
169    /// Connection state
170    state: RwLock<RelayConnectionState>,
171    /// Context manager (client role - even IDs)
172    context_manager: RwLock<ContextManager>,
173    /// Mapping: target address → context ID
174    target_to_context: RwLock<HashMap<SocketAddr, VarInt>>,
175    /// Pending datagrams waiting for context ACK
176    pending_datagrams: RwLock<Vec<PendingDatagram>>,
177    /// Connection timestamp
178    connected_at: RwLock<Option<Instant>>,
179    /// Statistics
180    stats: Arc<RelayClientStats>,
181}
182
183impl MasqueRelayClient {
184    /// Create a new relay client (not yet connected)
185    pub fn new(relay_address: SocketAddr, config: RelayClientConfig) -> Self {
186        Self {
187            config,
188            relay_address,
189            public_address: RwLock::new(None),
190            state: RwLock::new(RelayConnectionState::Disconnected),
191            context_manager: RwLock::new(ContextManager::new(true)), // Client role
192            target_to_context: RwLock::new(HashMap::new()),
193            pending_datagrams: RwLock::new(Vec::new()),
194            connected_at: RwLock::new(None),
195            stats: Arc::new(RelayClientStats::new()),
196        }
197    }
198
199    /// Get relay server address
200    pub fn relay_address(&self) -> SocketAddr {
201        self.relay_address
202    }
203
204    /// Get our public address (if known)
205    pub async fn public_address(&self) -> Option<SocketAddr> {
206        *self.public_address.read().await
207    }
208
209    /// Get current connection state
210    pub async fn state(&self) -> RelayConnectionState {
211        *self.state.read().await
212    }
213
214    /// Check if connected
215    pub async fn is_connected(&self) -> bool {
216        *self.state.read().await == RelayConnectionState::Connected
217    }
218
219    /// Get connection duration
220    pub async fn connection_duration(&self) -> Option<Duration> {
221        self.connected_at.read().await.map(|t| t.elapsed())
222    }
223
224    /// Get statistics
225    pub fn stats(&self) -> Arc<RelayClientStats> {
226        Arc::clone(&self.stats)
227    }
228
229    /// Create a CONNECT-UDP Bind request
230    pub fn create_connect_request(&self) -> ConnectUdpRequest {
231        ConnectUdpRequest::bind_any()
232    }
233
234    /// Handle the CONNECT-UDP response from the relay
235    pub async fn handle_connect_response(&self, response: ConnectUdpResponse) -> RelayResult<()> {
236        if !response.is_success() {
237            *self.state.write().await = RelayConnectionState::Failed;
238            return Err(RelayError::SessionError {
239                session_id: None,
240                kind: SessionErrorKind::InvalidState {
241                    current_state: format!("HTTP {}", response.status),
242                    expected_state: "HTTP 200".into(),
243                },
244            });
245        }
246
247        // Store public address if provided
248        if let Some(addr) = response.proxy_public_address {
249            *self.public_address.write().await = Some(addr);
250            tracing::info!(
251                relay = %self.relay_address,
252                public_addr = %addr,
253                "MASQUE relay session established"
254            );
255        }
256
257        *self.state.write().await = RelayConnectionState::Connected;
258        *self.connected_at.write().await = Some(Instant::now());
259
260        Ok(())
261    }
262
263    /// Handle an incoming capsule from the relay
264    pub async fn handle_capsule(&self, capsule: Capsule) -> RelayResult<Option<Capsule>> {
265        match capsule {
266            Capsule::CompressionAck(ack) => self.handle_ack(ack).await,
267            Capsule::CompressionClose(close) => self.handle_close(close).await,
268            Capsule::CompressionAssign(assign) => self.handle_assign(assign).await,
269            Capsule::Unknown { capsule_type, .. } => {
270                tracing::debug!(
271                    capsule_type = capsule_type.into_inner(),
272                    "Ignoring unknown capsule from relay"
273                );
274                Ok(None)
275            }
276        }
277    }
278
279    /// Handle COMPRESSION_ACK from relay
280    async fn handle_ack(&self, ack: CompressionAck) -> RelayResult<Option<Capsule>> {
281        let result = {
282            let mut mgr = self.context_manager.write().await;
283            mgr.handle_ack(ack.context_id)
284        }; // Release write lock before calling flush
285
286        match result {
287            Ok(_) => {
288                self.stats.record_context();
289                tracing::debug!(
290                    context_id = ack.context_id.into_inner(),
291                    "Context acknowledged by relay"
292                );
293
294                // Try to send any pending datagrams for this context
295                self.flush_pending_for_context(ack.context_id).await;
296                Ok(None)
297            }
298            Err(e) => {
299                tracing::warn!(
300                    context_id = ack.context_id.into_inner(),
301                    error = %e,
302                    "Unexpected ACK from relay"
303                );
304                Ok(None)
305            }
306        }
307    }
308
309    /// Handle COMPRESSION_CLOSE from relay
310    async fn handle_close(&self, close: CompressionClose) -> RelayResult<Option<Capsule>> {
311        let target = {
312            let mgr = self.context_manager.read().await;
313            mgr.get_target(close.context_id)
314        };
315
316        // Remove from our mapping
317        if let Some(t) = target {
318            self.target_to_context.write().await.remove(&t);
319        }
320
321        // Close in context manager
322        let mut mgr = self.context_manager.write().await;
323        let _ = mgr.close(close.context_id);
324
325        tracing::debug!(
326            context_id = close.context_id.into_inner(),
327            "Context closed by relay"
328        );
329
330        Ok(None)
331    }
332
333    /// Handle COMPRESSION_ASSIGN from relay (relay allocating context)
334    async fn handle_assign(&self, assign: CompressionAssign) -> RelayResult<Option<Capsule>> {
335        let target = assign.target();
336
337        // Register the remote context
338        {
339            let mut mgr = self.context_manager.write().await;
340            if let Err(e) = mgr.register_remote(assign.context_id, target) {
341                tracing::warn!(
342                    context_id = assign.context_id.into_inner(),
343                    error = %e,
344                    "Failed to register remote context"
345                );
346                // Send CLOSE to reject
347                return Ok(Some(Capsule::CompressionClose(CompressionClose::new(
348                    assign.context_id,
349                ))));
350            }
351        }
352
353        // Update target mapping
354        if let Some(t) = target {
355            self.target_to_context
356                .write()
357                .await
358                .insert(t, assign.context_id);
359        }
360
361        // Send ACK
362        Ok(Some(Capsule::CompressionAck(CompressionAck::new(
363            assign.context_id,
364        ))))
365    }
366
367    /// Get or create a context for a target address
368    ///
369    /// Returns the context ID and an optional capsule to send (COMPRESSION_ASSIGN).
370    pub async fn get_or_create_context(
371        &self,
372        target: SocketAddr,
373    ) -> RelayResult<(VarInt, Option<Capsule>)> {
374        // Check if we already have a context
375        {
376            let map = self.target_to_context.read().await;
377            if let Some(&ctx_id) = map.get(&target) {
378                let mgr = self.context_manager.read().await;
379                if let Some(info) = mgr.get_context(ctx_id) {
380                    if info.state == crate::masque::ContextState::Active {
381                        return Ok((ctx_id, None));
382                    }
383                }
384            }
385        }
386
387        // Allocate new context
388        let ctx_id = {
389            let mut mgr = self.context_manager.write().await;
390            let id = mgr
391                .allocate_local()
392                .map_err(|_| RelayError::ResourceExhausted {
393                    resource_type: "contexts".into(),
394                    current_usage: mgr.active_count() as u64,
395                    limit: self.config.max_pending_contexts as u64,
396                })?;
397
398            // Register as compressed context
399            mgr.register_compressed(id, target)
400                .map_err(|_| RelayError::SessionError {
401                    session_id: None,
402                    kind: SessionErrorKind::InvalidState {
403                        current_state: "duplicate target".into(),
404                        expected_state: "unique target".into(),
405                    },
406                })?;
407
408            id
409        };
410
411        // Add to target map (as pending)
412        self.target_to_context.write().await.insert(target, ctx_id);
413
414        // Create COMPRESSION_ASSIGN capsule
415        let assign = match target {
416            SocketAddr::V4(v4) => CompressionAssign::compressed_v4(ctx_id, *v4.ip(), v4.port()),
417            SocketAddr::V6(v6) => CompressionAssign::compressed_v6(ctx_id, *v6.ip(), v6.port()),
418        };
419
420        Ok((ctx_id, Some(Capsule::CompressionAssign(assign))))
421    }
422
423    /// Create a datagram for sending to a target
424    ///
425    /// If a context exists and is active, returns a compressed datagram.
426    /// Otherwise returns an uncompressed datagram (if allowed).
427    pub async fn create_datagram(
428        &self,
429        target: SocketAddr,
430        payload: Bytes,
431    ) -> RelayResult<(Datagram, Option<Capsule>)> {
432        // Try to get existing active context
433        {
434            let map = self.target_to_context.read().await;
435            if let Some(&ctx_id) = map.get(&target) {
436                let mgr = self.context_manager.read().await;
437                if let Some(info) = mgr.get_context(ctx_id) {
438                    if info.state == crate::masque::ContextState::Active {
439                        // Use compressed datagram
440                        let datagram = CompressedDatagram::new(ctx_id, payload);
441                        return Ok((Datagram::Compressed(datagram), None));
442                    }
443                }
444            }
445        }
446
447        // Create new context (always needed for both compressed and uncompressed)
448        let (ctx_id, capsule) = self.get_or_create_context(target).await?;
449
450        // Context is pending - queue the datagram
451        if capsule.is_some() {
452            self.pending_datagrams.write().await.push(PendingDatagram {
453                target,
454                payload: payload.clone(),
455                created_at: Instant::now(),
456            });
457        }
458
459        // Return compressed datagram (caller should send capsule first if returned)
460        let datagram = CompressedDatagram::new(ctx_id, payload);
461        Ok((Datagram::Compressed(datagram), capsule))
462    }
463
464    /// Flush pending datagrams for a context
465    async fn flush_pending_for_context(&self, ctx_id: VarInt) {
466        let target = {
467            let mgr = self.context_manager.read().await;
468            mgr.get_target(ctx_id)
469        };
470
471        if let Some(target) = target {
472            let mut pending = self.pending_datagrams.write().await;
473            pending.retain(|d| d.target != target);
474        }
475    }
476
477    /// Decode an incoming datagram from the relay
478    pub async fn decode_datagram(&self, data: &[u8]) -> RelayResult<(SocketAddr, Bytes)> {
479        // Try to decode as compressed first (more common)
480        if let Ok(datagram) = CompressedDatagram::decode(&mut bytes::Bytes::copy_from_slice(data)) {
481            let mgr = self.context_manager.read().await;
482            if let Some(target) = mgr.get_target(datagram.context_id) {
483                self.stats.record_received(datagram.payload.len() as u64);
484                return Ok((target, datagram.payload));
485            }
486        }
487
488        // Try uncompressed
489        if let Ok(datagram) = UncompressedDatagram::decode(&mut bytes::Bytes::copy_from_slice(data))
490        {
491            self.stats.record_received(datagram.payload.len() as u64);
492            return Ok((datagram.target, datagram.payload));
493        }
494
495        Err(RelayError::ProtocolError {
496            frame_type: 0,
497            reason: "Failed to decode datagram".into(),
498        })
499    }
500
501    /// Record a sent datagram
502    pub fn record_sent(&self, bytes: usize) {
503        self.stats.record_sent(bytes as u64);
504    }
505
506    /// Close the relay connection
507    pub async fn close(&self) {
508        *self.state.write().await = RelayConnectionState::Closed;
509
510        // Clear all contexts
511        self.target_to_context.write().await.clear();
512        self.pending_datagrams.write().await.clear();
513
514        tracing::info!(
515            relay = %self.relay_address,
516            "MASQUE relay client closed"
517        );
518    }
519
520    /// Get list of active context IDs
521    pub async fn active_contexts(&self) -> Vec<VarInt> {
522        let mgr = self.context_manager.read().await;
523        mgr.local_context_ids().collect()
524    }
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// A peer on a private network, parameterized by port.
    fn test_addr(port: u16) -> SocketAddr {
        SocketAddr::from((Ipv4Addr::new(192, 168, 1, 100), port))
    }

    /// The relay endpoint shared by every test.
    fn relay_addr() -> SocketAddr {
        SocketAddr::from((Ipv4Addr::new(203, 0, 113, 50), 9000))
    }

    /// A default-configured client that has not talked to the relay yet.
    fn new_client() -> MasqueRelayClient {
        MasqueRelayClient::new(relay_addr(), RelayClientConfig::default())
    }

    /// A default-configured client that accepted a successful connect response.
    async fn connected_client() -> MasqueRelayClient {
        let client = new_client();
        let accepted = ConnectUdpResponse::success(Some(test_addr(12345)));
        client.handle_connect_response(accepted).await.unwrap();
        client
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client = new_client();

        assert_eq!(client.relay_address(), relay_addr());
        assert!(!client.is_connected().await);
        assert!(client.public_address().await.is_none());
    }

    #[tokio::test]
    async fn test_connect_request() {
        let request = new_client().create_connect_request();
        assert!(request.connect_udp_bind);
    }

    #[tokio::test]
    async fn test_handle_success_response() {
        let client = new_client();
        let public_addr = test_addr(12345);

        client
            .handle_connect_response(ConnectUdpResponse::success(Some(public_addr)))
            .await
            .unwrap();

        assert!(client.is_connected().await);
        assert_eq!(client.public_address().await, Some(public_addr));
    }

    #[tokio::test]
    async fn test_handle_error_response() {
        let client = new_client();

        let outcome = client
            .handle_connect_response(ConnectUdpResponse::error(503, "Server busy"))
            .await;

        assert!(outcome.is_err());
        assert_eq!(client.state().await, RelayConnectionState::Failed);
    }

    #[tokio::test]
    async fn test_context_creation() {
        let client = connected_client().await;

        let (ctx_id, capsule) = client.get_or_create_context(test_addr(8080)).await.unwrap();

        // The first request for a target announces it with COMPRESSION_ASSIGN.
        assert!(capsule.is_some());
        assert!(matches!(capsule, Some(Capsule::CompressionAssign(_))));
        // Client-allocated context IDs are even.
        assert_eq!(ctx_id.into_inner() % 2, 0);
    }

    #[tokio::test]
    async fn test_handle_compression_ack() {
        let client = connected_client().await;
        let target = test_addr(8080);
        let (ctx_id, _) = client.get_or_create_context(target).await.unwrap();

        let reply = client
            .handle_capsule(Capsule::CompressionAck(CompressionAck::new(ctx_id)))
            .await;
        assert!(matches!(reply, Ok(None)));

        // Once acknowledged, the same context is returned with no new assignment.
        let (again, capsule) = client.get_or_create_context(target).await.unwrap();
        assert_eq!(again, ctx_id);
        assert!(capsule.is_none());
    }

    #[tokio::test]
    async fn test_handle_compression_close() {
        let client = connected_client().await;
        let target = test_addr(8080);
        let (ctx_id, _) = client.get_or_create_context(target).await.unwrap();

        client
            .handle_capsule(Capsule::CompressionAck(CompressionAck::new(ctx_id)))
            .await
            .unwrap();

        let closed = client
            .handle_capsule(Capsule::CompressionClose(CompressionClose::new(ctx_id)))
            .await;
        assert!(closed.is_ok());

        // The closed context is gone: a fresh ID and a fresh assignment follow.
        let (replacement, capsule) = client.get_or_create_context(target).await.unwrap();
        assert_ne!(replacement, ctx_id);
        assert!(capsule.is_some());
    }

    #[tokio::test]
    async fn test_create_datagram_compressed() {
        let client = MasqueRelayClient::new(
            relay_addr(),
            RelayClientConfig {
                prefer_compressed: true,
                ..Default::default()
            },
        );
        client
            .handle_connect_response(ConnectUdpResponse::success(Some(test_addr(12345))))
            .await
            .unwrap();

        let (datagram, capsule) = client
            .create_datagram(test_addr(8080), Bytes::from("Hello, relay!"))
            .await
            .unwrap();

        // A compressed datagram accompanied by the context assignment.
        assert!(matches!(datagram, Datagram::Compressed(_)));
        assert!(capsule.is_some());
    }

    #[tokio::test]
    async fn test_client_close() {
        let client = connected_client().await;
        assert!(client.is_connected().await);

        client.close().await;

        assert_eq!(client.state().await, RelayConnectionState::Closed);
    }

    #[tokio::test]
    async fn test_stats() {
        let client = new_client();
        let stats = client.stats();
        assert_eq!(stats.total_sent(), 0);
        assert_eq!(stats.total_received(), 0);

        client.record_sent(100);

        assert_eq!(stats.total_sent(), 100);
        assert_eq!(stats.datagrams_sent.load(Ordering::Relaxed), 1);
    }
}