1use bytes::Bytes;
40use std::collections::HashMap;
41use std::net::SocketAddr;
42use std::sync::Arc;
43use std::sync::atomic::{AtomicU64, Ordering};
44use std::time::{Duration, Instant};
45use tokio::sync::RwLock;
46
47use crate::VarInt;
48use crate::masque::{
49 Capsule, CompressedDatagram, CompressionAck, CompressionAssign, CompressionClose, ConnectError,
50 ConnectUdpRequest, ConnectUdpResponse, ContextManager, Datagram, UncompressedDatagram,
51};
52use crate::relay::error::{RelayError, RelayResult, SessionErrorKind};
53
/// Tunable settings for a [`MasqueRelayClient`].
///
/// Use [`Default`] for the stock values and struct-update syntax to override individual
/// fields, e.g. `RelayClientConfig { prefer_compressed: false, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct RelayClientConfig {
    /// Connect timeout (default: 10 s). Not read by the client code in this module.
    pub connect_timeout: Duration,
    /// Keep-alive interval (default: 30 s). Not read by the client code in this module.
    pub keepalive_interval: Duration,
    /// Context limit (default: 50). The client reports it as the `limit` in a
    /// `ResourceExhausted` error when context allocation fails; it does not enforce it itself.
    pub max_pending_contexts: usize,
    /// Whether compressed datagrams are preferred (default: `true`). Not consulted by the
    /// client code in this module, which always emits compressed datagrams.
    pub prefer_compressed: bool,
}

impl Default for RelayClientConfig {
    fn default() -> Self {
        let connect_timeout = Duration::from_secs(10);
        let keepalive_interval = Duration::from_secs(30);
        let max_pending_contexts = 50;
        let prefer_compressed = true;

        Self {
            connect_timeout,
            keepalive_interval,
            max_pending_contexts,
            prefer_compressed,
        }
    }
}
77
/// Lifecycle of a client's session with the relay.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RelayConnectionState {
    /// No session yet; the state every new client starts in.
    Disconnected,
    /// Session setup in progress.
    /// NOTE(review): never assigned by `MasqueRelayClient` itself — presumably set by the
    /// code driving the handshake; confirm.
    Connecting,
    /// A successful CONNECT-UDP response has been handled.
    Connected,
    /// The relay answered the CONNECT-UDP request with a non-success response.
    Failed,
    /// The client was shut down via `close()`.
    Closed,
}
92
/// Lock-free traffic counters for a relay client.
///
/// Every counter is updated and read with `Ordering::Relaxed`, so each value is individually
/// accurate, but reading several counters does not give an atomic snapshot of all of them.
#[derive(Debug, Default)]
pub struct RelayClientStats {
    /// Sum of the byte counts passed to [`RelayClientStats::record_sent`].
    pub bytes_sent: AtomicU64,
    /// Sum of the byte counts passed to [`RelayClientStats::record_received`].
    pub bytes_received: AtomicU64,
    /// Number of `record_sent` calls.
    pub datagrams_sent: AtomicU64,
    /// Number of `record_received` calls.
    pub datagrams_received: AtomicU64,
    /// Number of `record_context` calls.
    pub contexts_registered: AtomicU64,
    /// Connection attempts; not updated by any method of this type.
    pub connection_attempts: AtomicU64,
}

impl RelayClientStats {
    /// Creates a fresh set of counters, all zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Accounts for one outgoing datagram carrying `bytes` bytes.
    pub fn record_sent(&self, bytes: u64) {
        Self::bump(&self.bytes_sent, bytes);
        Self::bump(&self.datagrams_sent, 1);
    }

    /// Accounts for one incoming datagram carrying `bytes` bytes.
    pub fn record_received(&self, bytes: u64) {
        Self::bump(&self.bytes_received, bytes);
        Self::bump(&self.datagrams_received, 1);
    }

    /// Accounts for one compression context that became usable.
    pub fn record_context(&self) {
        Self::bump(&self.contexts_registered, 1);
    }

    /// Bytes sent so far.
    pub fn total_sent(&self) -> u64 {
        self.bytes_sent.load(Ordering::Relaxed)
    }

    /// Bytes received so far.
    pub fn total_received(&self) -> u64 {
        self.bytes_received.load(Ordering::Relaxed)
    }

    /// Relaxed increment shared by all the `record_*` methods.
    fn bump(counter: &AtomicU64, amount: u64) {
        counter.fetch_add(amount, Ordering::Relaxed);
    }
}
143
144#[derive(Debug)]
146struct PendingDatagram {
147 target: SocketAddr,
149 #[allow(dead_code)]
151 payload: Bytes,
152 #[allow(dead_code)]
154 created_at: Instant,
155}
156
157#[derive(Debug)]
162pub struct MasqueRelayClient {
163 config: RelayClientConfig,
165 relay_address: SocketAddr,
167 public_address: RwLock<Option<SocketAddr>>,
169 state: RwLock<RelayConnectionState>,
171 context_manager: RwLock<ContextManager>,
173 target_to_context: RwLock<HashMap<SocketAddr, VarInt>>,
175 pending_datagrams: RwLock<Vec<PendingDatagram>>,
177 connected_at: RwLock<Option<Instant>>,
179 stats: Arc<RelayClientStats>,
181}
182
183impl MasqueRelayClient {
184 pub fn new(relay_address: SocketAddr, config: RelayClientConfig) -> Self {
186 Self {
187 config,
188 relay_address,
189 public_address: RwLock::new(None),
190 state: RwLock::new(RelayConnectionState::Disconnected),
191 context_manager: RwLock::new(ContextManager::new(true)), target_to_context: RwLock::new(HashMap::new()),
193 pending_datagrams: RwLock::new(Vec::new()),
194 connected_at: RwLock::new(None),
195 stats: Arc::new(RelayClientStats::new()),
196 }
197 }
198
199 pub fn relay_address(&self) -> SocketAddr {
201 self.relay_address
202 }
203
204 pub async fn public_address(&self) -> Option<SocketAddr> {
206 *self.public_address.read().await
207 }
208
209 pub async fn state(&self) -> RelayConnectionState {
211 *self.state.read().await
212 }
213
214 pub async fn is_connected(&self) -> bool {
216 *self.state.read().await == RelayConnectionState::Connected
217 }
218
219 pub async fn connection_duration(&self) -> Option<Duration> {
221 self.connected_at.read().await.map(|t| t.elapsed())
222 }
223
224 pub fn stats(&self) -> Arc<RelayClientStats> {
226 Arc::clone(&self.stats)
227 }
228
229 pub fn create_connect_request(&self) -> ConnectUdpRequest {
231 ConnectUdpRequest::bind_any()
232 }
233
234 pub async fn handle_connect_response(&self, response: ConnectUdpResponse) -> RelayResult<()> {
236 if !response.is_success() {
237 *self.state.write().await = RelayConnectionState::Failed;
238 return Err(RelayError::SessionError {
239 session_id: None,
240 kind: SessionErrorKind::InvalidState {
241 current_state: format!("HTTP {}", response.status),
242 expected_state: "HTTP 200".into(),
243 },
244 });
245 }
246
247 if let Some(addr) = response.proxy_public_address {
249 *self.public_address.write().await = Some(addr);
250 tracing::info!(
251 relay = %self.relay_address,
252 public_addr = %addr,
253 "MASQUE relay session established"
254 );
255 }
256
257 *self.state.write().await = RelayConnectionState::Connected;
258 *self.connected_at.write().await = Some(Instant::now());
259
260 Ok(())
261 }
262
263 pub async fn handle_capsule(&self, capsule: Capsule) -> RelayResult<Option<Capsule>> {
265 match capsule {
266 Capsule::CompressionAck(ack) => self.handle_ack(ack).await,
267 Capsule::CompressionClose(close) => self.handle_close(close).await,
268 Capsule::CompressionAssign(assign) => self.handle_assign(assign).await,
269 Capsule::Unknown { capsule_type, .. } => {
270 tracing::debug!(
271 capsule_type = capsule_type.into_inner(),
272 "Ignoring unknown capsule from relay"
273 );
274 Ok(None)
275 }
276 }
277 }
278
279 async fn handle_ack(&self, ack: CompressionAck) -> RelayResult<Option<Capsule>> {
281 let result = {
282 let mut mgr = self.context_manager.write().await;
283 mgr.handle_ack(ack.context_id)
284 }; match result {
287 Ok(_) => {
288 self.stats.record_context();
289 tracing::debug!(
290 context_id = ack.context_id.into_inner(),
291 "Context acknowledged by relay"
292 );
293
294 self.flush_pending_for_context(ack.context_id).await;
296 Ok(None)
297 }
298 Err(e) => {
299 tracing::warn!(
300 context_id = ack.context_id.into_inner(),
301 error = %e,
302 "Unexpected ACK from relay"
303 );
304 Ok(None)
305 }
306 }
307 }
308
309 async fn handle_close(&self, close: CompressionClose) -> RelayResult<Option<Capsule>> {
311 let target = {
312 let mgr = self.context_manager.read().await;
313 mgr.get_target(close.context_id)
314 };
315
316 if let Some(t) = target {
318 self.target_to_context.write().await.remove(&t);
319 }
320
321 let mut mgr = self.context_manager.write().await;
323 let _ = mgr.close(close.context_id);
324
325 tracing::debug!(
326 context_id = close.context_id.into_inner(),
327 "Context closed by relay"
328 );
329
330 Ok(None)
331 }
332
333 async fn handle_assign(&self, assign: CompressionAssign) -> RelayResult<Option<Capsule>> {
335 let target = assign.target();
336
337 {
339 let mut mgr = self.context_manager.write().await;
340 if let Err(e) = mgr.register_remote(assign.context_id, target) {
341 tracing::warn!(
342 context_id = assign.context_id.into_inner(),
343 error = %e,
344 "Failed to register remote context"
345 );
346 return Ok(Some(Capsule::CompressionClose(CompressionClose::new(
348 assign.context_id,
349 ))));
350 }
351 }
352
353 if let Some(t) = target {
355 self.target_to_context
356 .write()
357 .await
358 .insert(t, assign.context_id);
359 }
360
361 Ok(Some(Capsule::CompressionAck(CompressionAck::new(
363 assign.context_id,
364 ))))
365 }
366
367 pub async fn get_or_create_context(
371 &self,
372 target: SocketAddr,
373 ) -> RelayResult<(VarInt, Option<Capsule>)> {
374 {
376 let map = self.target_to_context.read().await;
377 if let Some(&ctx_id) = map.get(&target) {
378 let mgr = self.context_manager.read().await;
379 if let Some(info) = mgr.get_context(ctx_id) {
380 if info.state == crate::masque::ContextState::Active {
381 return Ok((ctx_id, None));
382 }
383 }
384 }
385 }
386
387 let ctx_id = {
389 let mut mgr = self.context_manager.write().await;
390 let id = mgr
391 .allocate_local()
392 .map_err(|_| RelayError::ResourceExhausted {
393 resource_type: "contexts".into(),
394 current_usage: mgr.active_count() as u64,
395 limit: self.config.max_pending_contexts as u64,
396 })?;
397
398 mgr.register_compressed(id, target)
400 .map_err(|_| RelayError::SessionError {
401 session_id: None,
402 kind: SessionErrorKind::InvalidState {
403 current_state: "duplicate target".into(),
404 expected_state: "unique target".into(),
405 },
406 })?;
407
408 id
409 };
410
411 self.target_to_context.write().await.insert(target, ctx_id);
413
414 let assign = match target {
416 SocketAddr::V4(v4) => CompressionAssign::compressed_v4(ctx_id, *v4.ip(), v4.port()),
417 SocketAddr::V6(v6) => CompressionAssign::compressed_v6(ctx_id, *v6.ip(), v6.port()),
418 };
419
420 Ok((ctx_id, Some(Capsule::CompressionAssign(assign))))
421 }
422
423 pub async fn create_datagram(
428 &self,
429 target: SocketAddr,
430 payload: Bytes,
431 ) -> RelayResult<(Datagram, Option<Capsule>)> {
432 {
434 let map = self.target_to_context.read().await;
435 if let Some(&ctx_id) = map.get(&target) {
436 let mgr = self.context_manager.read().await;
437 if let Some(info) = mgr.get_context(ctx_id) {
438 if info.state == crate::masque::ContextState::Active {
439 let datagram = CompressedDatagram::new(ctx_id, payload);
441 return Ok((Datagram::Compressed(datagram), None));
442 }
443 }
444 }
445 }
446
447 let (ctx_id, capsule) = self.get_or_create_context(target).await?;
449
450 if capsule.is_some() {
452 self.pending_datagrams.write().await.push(PendingDatagram {
453 target,
454 payload: payload.clone(),
455 created_at: Instant::now(),
456 });
457 }
458
459 let datagram = CompressedDatagram::new(ctx_id, payload);
461 Ok((Datagram::Compressed(datagram), capsule))
462 }
463
464 async fn flush_pending_for_context(&self, ctx_id: VarInt) {
466 let target = {
467 let mgr = self.context_manager.read().await;
468 mgr.get_target(ctx_id)
469 };
470
471 if let Some(target) = target {
472 let mut pending = self.pending_datagrams.write().await;
473 pending.retain(|d| d.target != target);
474 }
475 }
476
477 pub async fn decode_datagram(&self, data: &[u8]) -> RelayResult<(SocketAddr, Bytes)> {
479 if let Ok(datagram) = CompressedDatagram::decode(&mut bytes::Bytes::copy_from_slice(data)) {
481 let mgr = self.context_manager.read().await;
482 if let Some(target) = mgr.get_target(datagram.context_id) {
483 self.stats.record_received(datagram.payload.len() as u64);
484 return Ok((target, datagram.payload));
485 }
486 }
487
488 if let Ok(datagram) = UncompressedDatagram::decode(&mut bytes::Bytes::copy_from_slice(data))
490 {
491 self.stats.record_received(datagram.payload.len() as u64);
492 return Ok((datagram.target, datagram.payload));
493 }
494
495 Err(RelayError::ProtocolError {
496 frame_type: 0,
497 reason: "Failed to decode datagram".into(),
498 })
499 }
500
501 pub fn record_sent(&self, bytes: usize) {
503 self.stats.record_sent(bytes as u64);
504 }
505
506 pub async fn close(&self) {
508 *self.state.write().await = RelayConnectionState::Closed;
509
510 self.target_to_context.write().await.clear();
512 self.pending_datagrams.write().await.clear();
513
514 tracing::info!(
515 relay = %self.relay_address,
516 "MASQUE relay client closed"
517 );
518 }
519
520 pub async fn active_contexts(&self) -> Vec<VarInt> {
522 let mgr = self.context_manager.read().await;
523 mgr.local_context_ids().collect()
524 }
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Private-LAN address used both as a datagram target and as a reported public address.
    fn test_addr(port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), port)
    }

    /// Fixed relay endpoint shared by every test.
    fn relay_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50)), 9000)
    }

    /// Fresh client with the default configuration.
    fn new_client() -> MasqueRelayClient {
        MasqueRelayClient::new(relay_addr(), RelayClientConfig::default())
    }

    /// Feeds `client` a successful CONNECT-UDP response and hands it back.
    async fn connected(client: MasqueRelayClient) -> MasqueRelayClient {
        let ok = ConnectUdpResponse::success(Some(test_addr(12345)));
        client.handle_connect_response(ok).await.unwrap();
        client
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client = new_client();

        assert_eq!(client.relay_address(), relay_addr());
        assert!(!client.is_connected().await);
        assert!(client.public_address().await.is_none());
    }

    #[tokio::test]
    async fn test_connect_request() {
        let request = new_client().create_connect_request();
        assert!(request.connect_udp_bind);
    }

    #[tokio::test]
    async fn test_handle_success_response() {
        let client = connected(new_client()).await;

        assert!(client.is_connected().await);
        assert_eq!(client.public_address().await, Some(test_addr(12345)));
    }

    #[tokio::test]
    async fn test_handle_error_response() {
        let client = new_client();

        let outcome = client
            .handle_connect_response(ConnectUdpResponse::error(503, "Server busy"))
            .await;

        assert!(outcome.is_err());
        assert_eq!(client.state().await, RelayConnectionState::Failed);
    }

    #[tokio::test]
    async fn test_context_creation() {
        let client = connected(new_client()).await;

        let (ctx_id, capsule) = client.get_or_create_context(test_addr(8080)).await.unwrap();

        // A brand-new context has to be announced to the relay.
        assert!(capsule.is_some());
        assert!(matches!(capsule, Some(Capsule::CompressionAssign(_))));
        // Client-allocated context ids are even.
        assert_eq!(ctx_id.into_inner() % 2, 0);
    }

    #[tokio::test]
    async fn test_handle_compression_ack() {
        let client = connected(new_client()).await;
        let target = test_addr(8080);
        let (ctx_id, _) = client.get_or_create_context(target).await.unwrap();

        let reply = client
            .handle_capsule(Capsule::CompressionAck(CompressionAck::new(ctx_id)))
            .await;
        assert!(reply.is_ok());
        assert!(reply.unwrap().is_none());

        // Once acknowledged, the same context is reused and nothing new is announced.
        let (again, capsule) = client.get_or_create_context(target).await.unwrap();
        assert_eq!(again, ctx_id);
        assert!(capsule.is_none());
    }

    #[tokio::test]
    async fn test_handle_compression_close() {
        let client = connected(new_client()).await;
        let target = test_addr(8080);
        let (ctx_id, _) = client.get_or_create_context(target).await.unwrap();

        client
            .handle_capsule(Capsule::CompressionAck(CompressionAck::new(ctx_id)))
            .await
            .unwrap();

        let reply = client
            .handle_capsule(Capsule::CompressionClose(CompressionClose::new(ctx_id)))
            .await;
        assert!(reply.is_ok());

        // A closed context is not reused: a fresh one is allocated and announced.
        let (replacement, capsule) = client.get_or_create_context(target).await.unwrap();
        assert_ne!(replacement, ctx_id);
        assert!(capsule.is_some());
    }

    #[tokio::test]
    async fn test_create_datagram_compressed() {
        let config = RelayClientConfig {
            prefer_compressed: true,
            ..Default::default()
        };
        let client = connected(MasqueRelayClient::new(relay_addr(), config)).await;

        let (datagram, capsule) = client
            .create_datagram(test_addr(8080), Bytes::from("Hello, relay!"))
            .await
            .unwrap();

        assert!(matches!(datagram, Datagram::Compressed(_)));
        assert!(capsule.is_some());
    }

    #[tokio::test]
    async fn test_client_close() {
        let client = connected(new_client()).await;
        assert!(client.is_connected().await);

        client.close().await;
        assert_eq!(client.state().await, RelayConnectionState::Closed);
    }

    #[tokio::test]
    async fn test_stats() {
        let client = new_client();
        let stats = client.stats();
        assert_eq!(stats.total_sent(), 0);
        assert_eq!(stats.total_received(), 0);

        client.record_sent(100);
        assert_eq!(stats.total_sent(), 100);
        assert_eq!(stats.datagrams_sent.load(Ordering::Relaxed), 1);
    }
}