use bytes::Bytes;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use crate::VarInt;
use crate::masque::{
Capsule, CompressedDatagram, CompressionAck, CompressionAssign, CompressionClose,
ConnectUdpRequest, ConnectUdpResponse, ContextManager, Datagram, UncompressedDatagram,
};
use crate::relay::error::{RelayError, RelayResult, SessionErrorKind};
/// Tunable settings for a `MasqueRelayClient`.
#[derive(Debug, Clone)]
pub struct RelayClientConfig {
    /// Time budget for establishing the relay session.
    ///
    /// NOTE(review): not read anywhere in this file; presumably enforced by the caller.
    pub connect_timeout: Duration,
    /// Interval between keepalives.
    ///
    /// NOTE(review): not read anywhere in this file; presumably consumed by the caller.
    pub keepalive_interval: Duration,
    /// Reported as the `limit` of the `ResourceExhausted` error raised when a
    /// local context id cannot be allocated.
    pub max_pending_contexts: usize,
    /// Whether compressed datagrams are preferred.
    ///
    /// NOTE(review): not read anywhere in this file; confirm where it is honoured.
    pub prefer_compressed: bool,
}

impl Default for RelayClientConfig {
    /// Returns the stock configuration: 10 s connect timeout, 30 s keepalive,
    /// 50 pending contexts, compression preferred.
    fn default() -> Self {
        let connect_timeout = Duration::from_secs(10);
        let keepalive_interval = Duration::from_secs(30);
        Self {
            prefer_compressed: true,
            max_pending_contexts: 50,
            keepalive_interval,
            connect_timeout,
        }
    }
}
/// Lifecycle state of the relay session tracked by `MasqueRelayClient`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RelayConnectionState {
/// Initial state assigned by `MasqueRelayClient::new`.
Disconnected,
/// NOTE(review): never assigned in this file; presumably set by the caller
/// while the connect request is in flight — confirm.
Connecting,
/// Set by `handle_connect_response` after a successful response.
Connected,
/// Set by `handle_connect_response` when the response is not a success.
Failed,
/// Set by `MasqueRelayClient::close`.
Closed,
}
/// Atomic traffic counters for a relay client.
///
/// All counters start at zero and are updated with relaxed ordering.
#[derive(Debug, Default)]
pub struct RelayClientStats {
    /// Total payload bytes passed to `record_sent`.
    pub bytes_sent: AtomicU64,
    /// Total payload bytes passed to `record_received`.
    pub bytes_received: AtomicU64,
    /// Number of `record_sent` calls.
    pub datagrams_sent: AtomicU64,
    /// Number of `record_received` calls.
    pub datagrams_received: AtomicU64,
    /// Number of `record_context` calls.
    pub contexts_registered: AtomicU64,
    /// Connection attempt counter; no method in this file updates it.
    pub connection_attempts: AtomicU64,
}

impl RelayClientStats {
    /// Creates a statistics block with every counter at zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Accounts for one outgoing datagram carrying `bytes` bytes.
    pub fn record_sent(&self, bytes: u64) {
        Self::bump(&self.bytes_sent, bytes);
        Self::bump(&self.datagrams_sent, 1);
    }

    /// Accounts for one incoming datagram carrying `bytes` bytes.
    pub fn record_received(&self, bytes: u64) {
        Self::bump(&self.bytes_received, bytes);
        Self::bump(&self.datagrams_received, 1);
    }

    /// Accounts for one registered context.
    pub fn record_context(&self) {
        Self::bump(&self.contexts_registered, 1);
    }

    /// Returns the total number of bytes recorded as sent.
    pub fn total_sent(&self) -> u64 {
        let counter = &self.bytes_sent;
        counter.load(Ordering::Relaxed)
    }

    /// Returns the total number of bytes recorded as received.
    pub fn total_received(&self) -> u64 {
        let counter = &self.bytes_received;
        counter.load(Ordering::Relaxed)
    }

    /// Adds `amount` to `counter` using relaxed ordering.
    fn bump(counter: &AtomicU64, amount: u64) {
        counter.fetch_add(amount, Ordering::Relaxed);
    }
}
/// Maximum age of a queued `PendingDatagram`. Entries this old or older are
/// dropped by `cleanup_stale_pending` and are not returned by
/// `flush_pending_for_context`.
const PENDING_DATAGRAM_MAX_AGE: Duration = Duration::from_secs(10);
/// A payload queued by `create_datagram` while the context for its target is
/// still awaiting acknowledgement.
#[derive(Debug)]
struct PendingDatagram {
// Destination the payload was addressed to; used to match the entry when a context is acknowledged.
target: SocketAddr,
// The queued payload bytes.
payload: Bytes,
// Time the entry was queued; compared against `PENDING_DATAGRAM_MAX_AGE`.
created_at: Instant,
}
/// Client-side state for a MASQUE relay session: connection state, compression
/// contexts, the target-to-context lookup table, queued datagrams and statistics.
#[derive(Debug)]
pub struct MasqueRelayClient {
// Settings supplied at construction.
config: RelayClientConfig,
// Address of the relay, as supplied at construction.
relay_address: SocketAddr,
// Public address reported by the relay in a successful connect response, if any.
public_address: RwLock<Option<SocketAddr>>,
// Current lifecycle state of the session.
state: RwLock<RelayConnectionState>,
// Registry of compression contexts (allocation, ACK, close, target lookup).
context_manager: RwLock<ContextManager>,
// Lookup from target address to the context id currently mapped to it.
target_to_context: RwLock<HashMap<SocketAddr, VarInt>>,
// Datagrams queued while their context awaits acknowledgement.
pending_datagrams: RwLock<Vec<PendingDatagram>>,
// Time at which the session entered the `Connected` state, if it has.
connected_at: RwLock<Option<Instant>>,
// Shared traffic counters, handed out via `stats()`.
stats: Arc<RelayClientStats>,
}
impl MasqueRelayClient {
/// Creates a client for the relay at `relay_address`.
///
/// The client starts in `RelayConnectionState::Disconnected` with no public
/// address, no contexts, an empty pending queue and zeroed statistics.
pub fn new(relay_address: SocketAddr, config: RelayClientConfig) -> Self {
    // NOTE(review): the `true` flag presumably selects the client role for
    // context-id allocation — confirm against `ContextManager::new`.
    let contexts = ContextManager::new(true);
    let initial_state = RelayConnectionState::Disconnected;

    Self {
        stats: Arc::new(RelayClientStats::new()),
        connected_at: RwLock::new(None),
        pending_datagrams: RwLock::new(Vec::new()),
        target_to_context: RwLock::new(HashMap::new()),
        context_manager: RwLock::new(contexts),
        state: RwLock::new(initial_state),
        public_address: RwLock::new(None),
        relay_address,
        config,
    }
}
/// Returns the relay address this client was constructed with.
pub fn relay_address(&self) -> SocketAddr {
self.relay_address
}
/// Returns the public address reported by the relay, or `None` if no
/// successful connect response carrying one has been handled yet.
pub async fn public_address(&self) -> Option<SocketAddr> {
    let advertised = self.public_address.read().await;
    *advertised
}
/// Returns a snapshot of the current connection state.
pub async fn state(&self) -> RelayConnectionState {
    let current = self.state.read().await;
    *current
}
/// Returns `true` only while the state is `RelayConnectionState::Connected`.
pub async fn is_connected(&self) -> bool {
    let current = *self.state.read().await;
    matches!(current, RelayConnectionState::Connected)
}
/// Returns the time elapsed since the session became connected, or `None`
/// if it never connected.
pub async fn connection_duration(&self) -> Option<Duration> {
    let connected_at = self.connected_at.read().await;
    (*connected_at).map(|since| since.elapsed())
}
/// Returns a shared handle to the client's statistics counters.
pub fn stats(&self) -> Arc<RelayClientStats> {
    let shared = &self.stats;
    Arc::clone(shared)
}
/// Builds the connect request to send to the relay, using
/// `ConnectUdpRequest::bind_any()`.
pub fn create_connect_request(&self) -> ConnectUdpRequest {
ConnectUdpRequest::bind_any()
}
/// Applies the relay's response to the connect request.
///
/// On success the public address (when the response carries one) is stored,
/// the state becomes `Connected` and the connection timestamp is set.
///
/// # Errors
///
/// Returns `RelayError::SessionError` (and moves the state to `Failed`) when
/// the response is not a success.
pub async fn handle_connect_response(&self, response: ConnectUdpResponse) -> RelayResult<()> {
    if response.is_success() {
        if let Some(addr) = response.proxy_public_address {
            {
                let mut public = self.public_address.write().await;
                *public = Some(addr);
            }
            tracing::info!(
                relay = %self.relay_address,
                public_addr = %addr,
                "MASQUE relay session established"
            );
        }

        {
            let mut state = self.state.write().await;
            *state = RelayConnectionState::Connected;
        }
        let mut connected_at = self.connected_at.write().await;
        *connected_at = Some(Instant::now());
        return Ok(());
    }

    // Rejected by the relay: record the failure before reporting it.
    {
        let mut state = self.state.write().await;
        *state = RelayConnectionState::Failed;
    }
    let kind = SessionErrorKind::InvalidState {
        current_state: format!("HTTP {}", response.status),
        expected_state: "HTTP 200".into(),
    };
    Err(RelayError::SessionError {
        session_id: None,
        kind,
    })
}
/// Dispatches a capsule received from the relay to its handler.
///
/// Returns the capsule to send back to the relay, if any. Unknown capsule
/// types are logged at debug level and ignored.
pub async fn handle_capsule(&self, capsule: Capsule) -> RelayResult<Option<Capsule>> {
    match capsule {
        Capsule::Unknown { capsule_type, .. } => {
            tracing::debug!(
                capsule_type = capsule_type.into_inner(),
                "Ignoring unknown capsule from relay"
            );
            Ok(None)
        }
        Capsule::CompressionAssign(assignment) => self.handle_assign(assignment).await,
        Capsule::CompressionAck(acknowledgement) => self.handle_ack(acknowledgement).await,
        Capsule::CompressionClose(closure) => self.handle_close(closure).await,
    }
}
/// Handles the relay's acknowledgement of a context.
///
/// On success the context is counted in the statistics and any datagrams
/// queued for its target are removed from the pending queue (only their
/// count is logged here). An ACK the context manager rejects is logged as a
/// warning and otherwise ignored. Never produces a reply capsule.
async fn handle_ack(&self, ack: CompressionAck) -> RelayResult<Option<Capsule>> {
    let ctx_id = ack.context_id;
    // The manager lock is released at the end of this statement, before the
    // pending queue is touched.
    let outcome = self.context_manager.write().await.handle_ack(ctx_id);

    if let Err(e) = outcome {
        tracing::warn!(
            context_id = ctx_id.into_inner(),
            error = %e,
            "Unexpected ACK from relay"
        );
        return Ok(None);
    }

    self.stats.record_context();
    tracing::debug!(
        context_id = ctx_id.into_inner(),
        "Context acknowledged by relay"
    );

    let released = self.flush_pending_for_context(ctx_id).await;
    let released_count = released.len();
    if released_count > 0 {
        tracing::debug!(
            context_id = ctx_id.into_inner(),
            count = released_count,
            "Flushed pending datagrams for acknowledged context"
        );
    }
    Ok(None)
}
/// Handles the relay closing a context.
///
/// Drops the target lookup entry that belongs to the closed context and then
/// closes the context in the manager. Never produces a reply capsule.
async fn handle_close(&self, close: CompressionClose) -> RelayResult<Option<Capsule>> {
    let target = {
        let mgr = self.context_manager.read().await;
        mgr.get_target(close.context_id)
    };
    if let Some(t) = target {
        let mut map = self.target_to_context.write().await;
        // Only remove the entry if it still refers to the context being
        // closed. The target may already have been re-mapped to another
        // context; removing that entry would orphan a live context.
        if map.get(&t) == Some(&close.context_id) {
            map.remove(&t);
        }
    }
    let mut mgr = self.context_manager.write().await;
    // Best effort: a close for an unknown or already-closed context is not an error here.
    let _ = mgr.close(close.context_id);
    tracing::debug!(
        context_id = close.context_id.into_inner(),
        "Context closed by relay"
    );
    Ok(None)
}
/// Handles a context assignment initiated by the relay.
///
/// Registers the remote context. On success, records the target lookup entry
/// (when the assignment carries a target) and replies with an ACK. If
/// registration fails, a warning is logged and the reply is a CLOSE for the
/// same context id.
async fn handle_assign(&self, assign: CompressionAssign) -> RelayResult<Option<Capsule>> {
    let ctx_id = assign.context_id;
    let target = assign.target();

    // The manager lock is released at the end of this statement.
    let registration = self
        .context_manager
        .write()
        .await
        .register_remote(ctx_id, target);

    if let Err(e) = registration {
        tracing::warn!(
            context_id = ctx_id.into_inner(),
            error = %e,
            "Failed to register remote context"
        );
        let rejection = CompressionClose::new(ctx_id);
        return Ok(Some(Capsule::CompressionClose(rejection)));
    }

    if let Some(addr) = target {
        let mut lookup = self.target_to_context.write().await;
        lookup.insert(addr, ctx_id);
    }

    let acknowledgement = CompressionAck::new(ctx_id);
    Ok(Some(Capsule::CompressionAck(acknowledgement)))
}
/// Returns the context id to use for `target`, allocating one if needed.
///
/// If `target` is already mapped to a context in the `Active` state, that id
/// is returned with no capsule. Otherwise a new local context is allocated
/// and registered, the lookup entry is (re)written, and the assignment
/// capsule the caller must send to the relay is returned alongside the id.
///
/// NOTE(review): a target whose mapped context exists but is not yet `Active`
/// also takes the allocation path; whether `register_compressed` accepts a
/// second registration for the same target is not visible here — confirm.
///
/// # Errors
///
/// * `RelayError::ResourceExhausted` when no local context id can be allocated.
/// * `RelayError::SessionError` when the manager refuses the registration.
pub async fn get_or_create_context(
    &self,
    target: SocketAddr,
) -> RelayResult<(VarInt, Option<Capsule>)> {
    // Fast path: reuse an already-active context for this target.
    {
        let known = self.target_to_context.read().await;
        if let Some(existing) = known.get(&target).copied() {
            let contexts = self.context_manager.read().await;
            let is_active = matches!(
                contexts.get_context(existing),
                Some(info) if info.state == crate::masque::ContextState::Active
            );
            if is_active {
                return Ok((existing, None));
            }
        }
    }

    // Slow path: allocate and register a fresh local context.
    let new_id = {
        let mut contexts = self.context_manager.write().await;
        let allocated = match contexts.allocate_local() {
            Ok(id) => id,
            Err(_) => {
                return Err(RelayError::ResourceExhausted {
                    resource_type: "contexts".into(),
                    current_usage: contexts.active_count() as u64,
                    limit: self.config.max_pending_contexts as u64,
                });
            }
        };
        if contexts.register_compressed(allocated, target).is_err() {
            return Err(RelayError::SessionError {
                session_id: None,
                kind: SessionErrorKind::InvalidState {
                    current_state: "duplicate target".into(),
                    expected_state: "unique target".into(),
                },
            });
        }
        allocated
    };

    {
        let mut lookup = self.target_to_context.write().await;
        lookup.insert(target, new_id);
    }

    let assignment = match target {
        SocketAddr::V6(addr) => CompressionAssign::compressed_v6(new_id, *addr.ip(), addr.port()),
        SocketAddr::V4(addr) => CompressionAssign::compressed_v4(new_id, *addr.ip(), addr.port()),
    };
    Ok((new_id, Some(Capsule::CompressionAssign(assignment))))
}
/// Wraps `payload` in a compressed datagram addressed to `target`.
///
/// When an `Active` context already exists for `target`, the datagram is
/// returned with no capsule. Otherwise a context is obtained through
/// `get_or_create_context`; if that produced an assignment capsule, a copy of
/// the payload is also queued as pending, and the capsule is returned so the
/// caller can send it to the relay.
///
/// # Errors
///
/// Propagates any error from `get_or_create_context`.
pub async fn create_datagram(
    &self,
    target: SocketAddr,
    payload: Bytes,
) -> RelayResult<(Datagram, Option<Capsule>)> {
    // Look for an active context first; both locks are released before the
    // datagram is built.
    let active_ctx = {
        let known = self.target_to_context.read().await;
        match known.get(&target).copied() {
            Some(candidate) => {
                let contexts = self.context_manager.read().await;
                let is_active = matches!(
                    contexts.get_context(candidate),
                    Some(info) if info.state == crate::masque::ContextState::Active
                );
                if is_active {
                    Some(candidate)
                } else {
                    None
                }
            }
            None => None,
        }
    };
    if let Some(ctx_id) = active_ctx {
        let datagram = CompressedDatagram::new(ctx_id, payload);
        return Ok((Datagram::Compressed(datagram), None));
    }

    let (ctx_id, capsule) = self.get_or_create_context(target).await?;
    if capsule.is_some() {
        // A new context was just requested: remember the payload until the
        // relay acknowledges it.
        let mut queue = self.pending_datagrams.write().await;
        queue.push(PendingDatagram {
            target,
            payload: payload.clone(),
            created_at: Instant::now(),
        });
    }

    let datagram = Datagram::Compressed(CompressedDatagram::new(ctx_id, payload));
    Ok((datagram, capsule))
}
/// Removes every pending datagram queued for the target of `ctx_id`.
///
/// Returns the payloads of the removed entries that are still younger than
/// `PENDING_DATAGRAM_MAX_AGE`, in queue order; older entries for that target
/// are dropped. Entries for other targets are left untouched. Returns an
/// empty vector when the context has no known target.
async fn flush_pending_for_context(&self, ctx_id: VarInt) -> Vec<Bytes> {
    // The manager lock is released at the end of this statement.
    let resolved = self.context_manager.read().await.get_target(ctx_id);
    let target = match resolved {
        Some(addr) => addr,
        None => return Vec::new(),
    };

    let mut queue = self.pending_datagrams.write().await;
    let now = Instant::now();
    let mut ready = Vec::new();
    let mut kept = Vec::with_capacity(queue.len());

    for entry in queue.drain(..) {
        if entry.target != target {
            kept.push(entry);
        } else if now.duration_since(entry.created_at) < PENDING_DATAGRAM_MAX_AGE {
            ready.push(entry.payload);
        }
        // Stale entries for this target are simply dropped.
    }

    *queue = kept;
    ready
}
/// Drops pending datagrams that have reached `PENDING_DATAGRAM_MAX_AGE`.
///
/// Returns the number of entries removed.
pub async fn cleanup_stale_pending(&self) -> usize {
    let mut queue = self.pending_datagrams.write().await;
    let now = Instant::now();
    let mut dropped = 0usize;

    queue.retain(|entry| {
        let fresh = now.duration_since(entry.created_at) < PENDING_DATAGRAM_MAX_AGE;
        if !fresh {
            dropped += 1;
        }
        fresh
    });

    dropped
}
pub async fn decode_datagram(&self, data: &[u8]) -> RelayResult<(SocketAddr, Bytes)> {
if let Ok(datagram) = CompressedDatagram::decode(&mut bytes::Bytes::copy_from_slice(data)) {
let mgr = self.context_manager.read().await;
if let Some(target) = mgr.get_target(datagram.context_id) {
self.stats.record_received(datagram.payload.len() as u64);
return Ok((target, datagram.payload));
}
}
if let Ok(datagram) = UncompressedDatagram::decode(&mut bytes::Bytes::copy_from_slice(data))
{
self.stats.record_received(datagram.payload.len() as u64);
return Ok((datagram.target, datagram.payload));
}
Err(RelayError::ProtocolError {
frame_type: 0,
reason: "Failed to decode datagram".into(),
})
}
/// Records one sent datagram of `bytes` bytes in the client's statistics.
pub fn record_sent(&self, bytes: usize) {
self.stats.record_sent(bytes as u64);
}
/// Marks the client as closed and discards the target lookup table and the
/// pending datagram queue.
///
/// The context manager, the stored public address and the connection
/// timestamp are not modified by this method.
pub async fn close(&self) {
    {
        let mut state = self.state.write().await;
        *state = RelayConnectionState::Closed;
    }
    {
        let mut lookup = self.target_to_context.write().await;
        lookup.clear();
    }
    {
        let mut queue = self.pending_datagrams.write().await;
        queue.clear();
    }
    tracing::info!(
        relay = %self.relay_address,
        "MASQUE relay client closed"
    );
}
/// Returns the ids yielded by the context manager's `local_context_ids()`.
pub async fn active_contexts(&self) -> Vec<VarInt> {
    let contexts = self.context_manager.read().await;
    contexts.local_context_ids().collect::<Vec<VarInt>>()
}
}
// Unit tests for `MasqueRelayClient`, driven purely through in-memory
// request/response/capsule values (no network I/O).
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr};
// Builds a target/peer address 192.168.1.100:<port>.
fn test_addr(port: u16) -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), port)
}
// Fixed relay address used by every test.
fn relay_addr() -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50)), 9000)
}
// A fresh client is not connected and has no public address.
#[tokio::test]
async fn test_client_creation() {
let config = RelayClientConfig::default();
let client = MasqueRelayClient::new(relay_addr(), config);
assert_eq!(client.relay_address(), relay_addr());
assert!(!client.is_connected().await);
assert!(client.public_address().await.is_none());
}
// The connect request has `connect_udp_bind` set.
#[tokio::test]
async fn test_connect_request() {
let config = RelayClientConfig::default();
let client = MasqueRelayClient::new(relay_addr(), config);
let request = client.create_connect_request();
assert!(request.connect_udp_bind);
}
// A success response moves the client to Connected and stores the public address.
#[tokio::test]
async fn test_handle_success_response() {
let config = RelayClientConfig::default();
let client = MasqueRelayClient::new(relay_addr(), config);
let public_addr = test_addr(12345);
let response = ConnectUdpResponse::success(Some(public_addr));
client.handle_connect_response(response).await.unwrap();
assert!(client.is_connected().await);
assert_eq!(client.public_address().await, Some(public_addr));
}
// An error response yields Err and moves the client to Failed.
#[tokio::test]
async fn test_handle_error_response() {
let config = RelayClientConfig::default();
let client = MasqueRelayClient::new(relay_addr(), config);
let response = ConnectUdpResponse::error(503, "Server busy");
let result = client.handle_connect_response(response).await;
assert!(result.is_err());
assert_eq!(client.state().await, RelayConnectionState::Failed);
}
// The first request for a target allocates a context and returns an assignment capsule.
#[tokio::test]
async fn test_context_creation() {
let config = RelayClientConfig::default();
let client = MasqueRelayClient::new(relay_addr(), config);
let response = ConnectUdpResponse::success(Some(test_addr(12345)));
client.handle_connect_response(response).await.unwrap();
let target = test_addr(8080);
let (ctx_id, capsule) = client.get_or_create_context(target).await.unwrap();
assert!(capsule.is_some());
assert!(matches!(capsule, Some(Capsule::CompressionAssign(_))));
// Locally allocated context ids are expected to be even.
assert_eq!(ctx_id.into_inner() % 2, 0);
}
// After an ACK, the same context id is reused and no new capsule is produced.
#[tokio::test]
async fn test_handle_compression_ack() {
let config = RelayClientConfig::default();
let client = MasqueRelayClient::new(relay_addr(), config);
let response = ConnectUdpResponse::success(Some(test_addr(12345)));
client.handle_connect_response(response).await.unwrap();
let target = test_addr(8080);
let (ctx_id, _) = client.get_or_create_context(target).await.unwrap();
let ack = CompressionAck::new(ctx_id);
let result = client.handle_capsule(Capsule::CompressionAck(ack)).await;
assert!(result.is_ok());
assert!(result.unwrap().is_none());
let (new_ctx_id, capsule) = client.get_or_create_context(target).await.unwrap();
assert_eq!(new_ctx_id, ctx_id);
assert!(capsule.is_none()); }
// After a CLOSE, the next request for the target allocates a different context.
#[tokio::test]
async fn test_handle_compression_close() {
let config = RelayClientConfig::default();
let client = MasqueRelayClient::new(relay_addr(), config);
let response = ConnectUdpResponse::success(Some(test_addr(12345)));
client.handle_connect_response(response).await.unwrap();
let target = test_addr(8080);
let (ctx_id, _) = client.get_or_create_context(target).await.unwrap();
let ack = CompressionAck::new(ctx_id);
client
.handle_capsule(Capsule::CompressionAck(ack))
.await
.unwrap();
let close = CompressionClose::new(ctx_id);
let result = client
.handle_capsule(Capsule::CompressionClose(close))
.await;
assert!(result.is_ok());
let (new_ctx_id, capsule) = client.get_or_create_context(target).await.unwrap();
assert_ne!(new_ctx_id, ctx_id); assert!(capsule.is_some()); }
// The first datagram to a new target is compressed and comes with an assignment capsule.
#[tokio::test]
async fn test_create_datagram_compressed() {
let config = RelayClientConfig {
prefer_compressed: true,
..Default::default()
};
let client = MasqueRelayClient::new(relay_addr(), config);
let response = ConnectUdpResponse::success(Some(test_addr(12345)));
client.handle_connect_response(response).await.unwrap();
let target = test_addr(8080);
let payload = Bytes::from("Hello, relay!");
let (datagram, capsule) = client.create_datagram(target, payload).await.unwrap();
assert!(matches!(datagram, Datagram::Compressed(_)));
assert!(capsule.is_some());
}
// close() moves a connected client to Closed.
#[tokio::test]
async fn test_client_close() {
let config = RelayClientConfig::default();
let client = MasqueRelayClient::new(relay_addr(), config);
let response = ConnectUdpResponse::success(Some(test_addr(12345)));
client.handle_connect_response(response).await.unwrap();
assert!(client.is_connected().await);
client.close().await;
assert_eq!(client.state().await, RelayConnectionState::Closed);
}
// record_sent updates both the byte and datagram counters of the shared stats.
#[tokio::test]
async fn test_stats() {
let config = RelayClientConfig::default();
let client = MasqueRelayClient::new(relay_addr(), config);
let stats = client.stats();
assert_eq!(stats.total_sent(), 0);
assert_eq!(stats.total_received(), 0);
client.record_sent(100);
assert_eq!(stats.total_sent(), 100);
assert_eq!(stats.datagrams_sent.load(Ordering::Relaxed), 1);
}
// A datagram queued while its context is pending is removed from the queue once the ACK arrives.
#[tokio::test]
async fn test_flush_pending_returns_payloads() {
let config = RelayClientConfig {
prefer_compressed: true,
..Default::default()
};
let client = MasqueRelayClient::new(relay_addr(), config);
let response = ConnectUdpResponse::success(Some(test_addr(12345)));
client.handle_connect_response(response).await.unwrap();
let target = test_addr(8080);
let payload = Bytes::from("queued data");
let (_datagram, capsule) = client
.create_datagram(target, payload.clone())
.await
.unwrap();
assert!(capsule.is_some(), "First call should create a new context");
let ctx_id = match capsule.unwrap() {
Capsule::CompressionAssign(assign) => assign.context_id,
_ => panic!("Expected CompressionAssign capsule"),
};
let ack = CompressionAck::new(ctx_id);
client
.handle_capsule(Capsule::CompressionAck(ack))
.await
.unwrap();
// Nothing should be left for the cleanup pass to remove.
let cleaned = client.cleanup_stale_pending().await;
assert_eq!(cleaned, 0, "All pending should have been flushed already");
}
// cleanup_stale_pending drops only entries older than the max age.
#[tokio::test]
async fn test_cleanup_stale_pending() {
let config = RelayClientConfig::default();
let client = MasqueRelayClient::new(relay_addr(), config);
{
let mut pending = client.pending_datagrams.write().await;
// 15 s old: beyond the 10 s limit, so it must be removed.
pending.push(PendingDatagram {
target: test_addr(8080),
payload: Bytes::from("old data"),
created_at: Instant::now() - Duration::from_secs(15), });
// Just created: must survive the cleanup.
pending.push(PendingDatagram {
target: test_addr(9090),
payload: Bytes::from("fresh data"),
created_at: Instant::now(), });
}
let cleaned = client.cleanup_stale_pending().await;
assert_eq!(cleaned, 1, "Should have cleaned 1 stale datagram");
let remaining = client.pending_datagrams.read().await;
assert_eq!(remaining.len(), 1, "One fresh datagram should remain");
assert_eq!(remaining[0].target, test_addr(9090));
}
}