use bytes::Bytes;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use crate::masque::{
Capsule, ConnectUdpRequest, ConnectUdpResponse, Datagram, RelaySession, RelaySessionConfig,
RelaySessionState,
};
use crate::relay::error::{RelayError, RelayResult, SessionErrorKind};
/// Tunables for a [`MasqueRelayServer`].
#[derive(Debug, Clone)]
pub struct MasqueRelayConfig {
    /// Maximum number of relay sessions the server admits at once; further
    /// CONNECT-UDP requests are answered with status 503.
    pub max_sessions: usize,
    /// Configuration cloned into every newly created `RelaySession`.
    pub session_config: RelaySessionConfig,
    /// Intended sweep period for expired sessions.
    /// NOTE(review): not read by the server code in this module — presumably used by
    /// whatever task drives `cleanup_expired_sessions`; confirm at the call site.
    pub cleanup_interval: Duration,
    /// Relay-wide bandwidth budget (default 100 MiB).
    /// NOTE(review): units are presumably bytes per second, and this module does not
    /// enforce it — verify where it is consumed.
    pub global_bandwidth_limit: u64,
    /// Whether clients must authenticate.
    /// NOTE(review): not checked by the server code in this module.
    pub require_authentication: bool,
}

impl Default for MasqueRelayConfig {
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;

        let session_config = RelaySessionConfig::default();
        Self {
            max_sessions: 1000,
            session_config,
            cleanup_interval: Duration::from_secs(60),
            global_bandwidth_limit: 100 * MIB,
            require_authentication: true,
        }
    }
}
/// Lock-free counters describing relay activity.
///
/// Every counter is an independent statistic updated with `Ordering::Relaxed`; no other
/// data is published through them, so readers may see the counters slightly out of step
/// with one another.
#[derive(Debug, Default)]
pub struct MasqueRelayStats {
    /// Total number of sessions ever created.
    pub sessions_created: AtomicU64,
    /// Gauge of sessions currently open (never goes below zero).
    pub active_sessions: AtomicU64,
    /// Total number of session terminations recorded.
    pub sessions_terminated: AtomicU64,
    /// Total bytes accounted by the relay.
    pub bytes_relayed: AtomicU64,
    /// Total datagrams accounted by the relay.
    pub datagrams_forwarded: AtomicU64,
    /// Total authentication failures recorded.
    pub auth_failures: AtomicU64,
    /// Total requests/datagrams rejected by rate limiting.
    pub rate_limit_rejections: AtomicU64,
}

impl MasqueRelayStats {
    /// Creates a set of counters, all starting at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records a new session: bumps the lifetime total and the active gauge.
    pub fn record_session_created(&self) {
        self.sessions_created.fetch_add(1, Ordering::Relaxed);
        self.active_sessions.fetch_add(1, Ordering::Relaxed);
    }

    /// Records a session termination.
    ///
    /// The active gauge saturates at zero: a plain `fetch_sub` would wrap an unmatched
    /// termination to `u64::MAX`, after which any capacity check based on
    /// [`Self::current_active_sessions`] would reject every new session.
    pub fn record_session_terminated(&self) {
        self.sessions_terminated.fetch_add(1, Ordering::Relaxed);
        // `Err` here only means the gauge was already zero, which is exactly the case
        // we are guarding against, so it is deliberately ignored.
        let _ = self
            .active_sessions
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1));
    }

    /// Adds `bytes` to the relayed-bytes total.
    pub fn record_bytes(&self, bytes: u64) {
        self.bytes_relayed.fetch_add(bytes, Ordering::Relaxed);
    }

    /// Counts one relayed datagram.
    pub fn record_datagram(&self) {
        self.datagrams_forwarded.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one authentication failure.
    pub fn record_auth_failure(&self) {
        self.auth_failures.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one rate-limit rejection.
    pub fn record_rate_limit(&self) {
        self.rate_limit_rejections.fetch_add(1, Ordering::Relaxed);
    }

    /// Returns the current value of the active-session gauge.
    pub fn current_active_sessions(&self) -> u64 {
        self.active_sessions.load(Ordering::Relaxed)
    }

    /// Returns the total number of bytes accounted so far.
    pub fn total_bytes_relayed(&self) -> u64 {
        self.bytes_relayed.load(Ordering::Relaxed)
    }
}
/// A payload the relay should transmit to a target on behalf of a session.
#[derive(Debug, Clone)]
pub struct OutboundDatagram {
    /// Address the payload is destined for.
    pub target: SocketAddr,
    /// Bytes to transmit.
    pub payload: Bytes,
    /// Identifier of the relay session that produced this datagram.
    pub session_id: u64,
}

/// Outcome of processing a datagram received from a client.
#[derive(Debug)]
pub enum DatagramResult {
    /// The payload should be sent on to the contained target.
    Forward(OutboundDatagram),
    /// Presumably: consumed by the relay itself, nothing to forward.
    /// NOTE(review): never constructed in this module — confirm intended meaning.
    Internal,
    /// The sending client has no relay session (or it vanished mid-lookup).
    SessionNotFound,
    /// The datagram could not be processed.
    Error(RelayError),
}
/// MASQUE CONNECT-UDP relay: admits client sessions and relays their datagrams.
#[derive(Debug)]
pub struct MasqueRelayServer {
    /// Limits and per-session defaults.
    config: MasqueRelayConfig,
    /// Primary public address (the IPv4 one when constructed via `new_dual_stack`
    /// with a genuine IPv4 address).
    public_address: SocketAddr,
    /// Optional second public address, used for the other IP family.
    secondary_address: Option<SocketAddr>,
    /// Live sessions keyed by session ID.
    sessions: RwLock<HashMap<u64, RelaySession>>,
    /// Reverse index from a client's address to its session ID.
    client_to_session: RwLock<HashMap<SocketAddr, u64>>,
    /// Next session ID to hand out (starts at 1).
    next_session_id: AtomicU64,
    /// Shared, lock-free activity counters.
    stats: Arc<MasqueRelayStats>,
    /// Construction time, used for `uptime`.
    started_at: Instant,
    /// Number of sessions admitted that cross IP families.
    bridged_connections: AtomicU64,
}
impl MasqueRelayServer {
pub fn new(config: MasqueRelayConfig, public_address: SocketAddr) -> Self {
Self {
config,
public_address,
secondary_address: None,
sessions: RwLock::new(HashMap::new()),
client_to_session: RwLock::new(HashMap::new()),
next_session_id: AtomicU64::new(1),
stats: Arc::new(MasqueRelayStats::new()),
started_at: Instant::now(),
bridged_connections: AtomicU64::new(0),
}
}
pub fn new_dual_stack(
config: MasqueRelayConfig,
ipv4_address: SocketAddr,
ipv6_address: SocketAddr,
) -> Self {
let (primary, secondary) = if ipv4_address.is_ipv4() {
(ipv4_address, ipv6_address)
} else {
(ipv6_address, ipv4_address)
};
Self {
config,
public_address: primary,
secondary_address: Some(secondary),
sessions: RwLock::new(HashMap::new()),
client_to_session: RwLock::new(HashMap::new()),
next_session_id: AtomicU64::new(1),
stats: Arc::new(MasqueRelayStats::new()),
started_at: Instant::now(),
bridged_connections: AtomicU64::new(0),
}
}
pub fn supports_dual_stack(&self) -> bool {
if let Some(secondary) = self.secondary_address {
self.public_address.is_ipv4() != secondary.is_ipv4()
} else {
false
}
}
pub async fn can_bridge(&self, source: SocketAddr, target: SocketAddr) -> bool {
let source_v4 = source.is_ipv4();
let target_v4 = target.is_ipv4();
if source_v4 == target_v4 {
return true;
}
self.supports_dual_stack()
}
pub fn address_for_target(&self, target: &SocketAddr) -> SocketAddr {
if let Some(secondary) = self.secondary_address {
let target_v4 = target.is_ipv4();
if self.public_address.is_ipv4() == target_v4 {
self.public_address
} else {
secondary
}
} else {
self.public_address
}
}
pub fn secondary_address(&self) -> Option<SocketAddr> {
self.secondary_address
}
pub fn bridged_connection_count(&self) -> u64 {
self.bridged_connections.load(Ordering::Relaxed)
}
fn record_bridged_connection(&self) {
self.bridged_connections.fetch_add(1, Ordering::Relaxed);
}
pub fn stats(&self) -> Arc<MasqueRelayStats> {
Arc::clone(&self.stats)
}
pub fn uptime(&self) -> Duration {
self.started_at.elapsed()
}
pub fn public_address(&self) -> SocketAddr {
self.public_address
}
pub async fn handle_connect_request(
&self,
request: &ConnectUdpRequest,
client_addr: SocketAddr,
) -> RelayResult<ConnectUdpResponse> {
let current_sessions = self.stats.current_active_sessions();
if current_sessions >= self.config.max_sessions as u64 {
return Ok(ConnectUdpResponse::error(
503,
"Server at capacity".to_string(),
));
}
{
let client_sessions = self.client_to_session.read().await;
if client_sessions.contains_key(&client_addr) {
return Ok(ConnectUdpResponse::error(
409,
"Session already exists for this client".to_string(),
));
}
}
let requires_bridging = if let Some(target) = request.target_address() {
let client_v4 = client_addr.is_ipv4();
let target_v4 = target.is_ipv4();
client_v4 != target_v4
} else {
false
};
if requires_bridging && !self.supports_dual_stack() {
return Ok(ConnectUdpResponse::error(
501,
"IPv4/IPv6 bridging not supported by this relay".to_string(),
));
}
let advertised_address = if client_addr.is_ipv4() {
if self.public_address.is_ipv4() {
self.public_address
} else {
self.secondary_address.unwrap_or(self.public_address)
}
} else if self.public_address.is_ipv6() {
self.public_address
} else {
self.secondary_address.unwrap_or(self.public_address)
};
let session_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
let mut session = RelaySession::new(
session_id,
self.config.session_config.clone(),
advertised_address,
);
session.set_client_address(client_addr);
if requires_bridging {
session.set_bridging(true);
}
session.activate()?;
{
let mut sessions = self.sessions.write().await;
sessions.insert(session_id, session);
}
{
let mut client_map = self.client_to_session.write().await;
client_map.insert(client_addr, session_id);
}
self.stats.record_session_created();
if requires_bridging {
self.record_bridged_connection();
}
tracing::info!(
session_id = session_id,
client = %client_addr,
public_addr = %advertised_address,
bridging = requires_bridging,
dual_stack = self.supports_dual_stack(),
"MASQUE relay session created"
);
Ok(ConnectUdpResponse::success(Some(advertised_address)))
}
pub async fn get_session_for_client(&self, client_addr: SocketAddr) -> Option<SessionInfo> {
let session_id = {
let client_map = self.client_to_session.read().await;
client_map.get(&client_addr).copied()?
};
self.get_session_info(session_id).await
}
pub async fn terminate_session_for_client(&self, client_addr: SocketAddr) {
let _ = self.close_session_by_client(client_addr).await;
}
pub async fn forward_datagram(
&self,
client_addr: SocketAddr,
_target: SocketAddr,
payload: Bytes,
) -> RelayResult<()> {
let session_id = {
let client_map = self.client_to_session.read().await;
client_map
.get(&client_addr)
.copied()
.ok_or(RelayError::SessionError {
session_id: None,
kind: SessionErrorKind::NotFound,
})?
};
let sessions = self.sessions.read().await;
let session = sessions.get(&session_id).ok_or(RelayError::SessionError {
session_id: Some(session_id as u32),
kind: SessionErrorKind::NotFound,
})?;
if !session.check_rate_limit(payload.len()) {
self.stats.record_rate_limit();
return Err(RelayError::RateLimitExceeded {
retry_after_ms: 1000, });
}
self.stats.record_bytes(payload.len() as u64);
self.stats.record_datagram();
Ok(())
}
pub async fn handle_capsule(
&self,
client_addr: SocketAddr,
capsule: Capsule,
) -> RelayResult<Option<Capsule>> {
let session_id = {
let client_map = self.client_to_session.read().await;
client_map
.get(&client_addr)
.copied()
.ok_or(RelayError::SessionError {
session_id: None,
kind: SessionErrorKind::NotFound,
})?
};
let mut sessions = self.sessions.write().await;
let session = sessions
.get_mut(&session_id)
.ok_or(RelayError::SessionError {
session_id: Some(session_id as u32),
kind: SessionErrorKind::NotFound,
})?;
session.handle_capsule(capsule)
}
pub async fn handle_client_datagram(
&self,
client_addr: SocketAddr,
datagram: Datagram,
payload: Bytes,
) -> DatagramResult {
let session_id = {
let client_map = self.client_to_session.read().await;
match client_map.get(&client_addr) {
Some(&id) => id,
None => return DatagramResult::SessionNotFound,
}
};
let target = {
let sessions = self.sessions.read().await;
let session = match sessions.get(&session_id) {
Some(s) => s,
None => return DatagramResult::SessionNotFound,
};
match session.resolve_target(&datagram) {
Some(t) => t,
None => {
return DatagramResult::Error(RelayError::ProtocolError {
frame_type: 0x00,
reason: "Unknown context ID".into(),
});
}
}
};
self.stats.record_bytes(payload.len() as u64);
self.stats.record_datagram();
DatagramResult::Forward(OutboundDatagram {
target,
payload,
session_id,
})
}
pub async fn handle_target_datagram(
&self,
session_id: u64,
source: SocketAddr,
payload: Bytes,
) -> RelayResult<(SocketAddr, Bytes)> {
let mut sessions = self.sessions.write().await;
let session = sessions
.get_mut(&session_id)
.ok_or(RelayError::SessionError {
session_id: Some(session_id as u32),
kind: SessionErrorKind::NotFound,
})?;
let client_addr = session.client_address().ok_or(RelayError::SessionError {
session_id: Some(session_id as u32),
kind: SessionErrorKind::InvalidState {
current_state: "no client address".into(),
expected_state: "client address set".into(),
},
})?;
let ctx_id = session.context_for_target(source)?;
let datagram = crate::masque::CompressedDatagram::new(ctx_id, payload.clone());
let encoded = datagram.encode();
self.stats.record_bytes(encoded.len() as u64);
self.stats.record_datagram();
Ok((client_addr, encoded))
}
pub async fn close_session(&self, session_id: u64) -> RelayResult<()> {
let client_addr = {
let mut sessions = self.sessions.write().await;
let session = sessions
.get_mut(&session_id)
.ok_or(RelayError::SessionError {
session_id: Some(session_id as u32),
kind: SessionErrorKind::NotFound,
})?;
let addr = session.client_address();
session.close();
addr
};
{
let mut sessions = self.sessions.write().await;
sessions.remove(&session_id);
}
if let Some(addr) = client_addr {
let mut client_map = self.client_to_session.write().await;
client_map.remove(&addr);
}
self.stats.record_session_terminated();
tracing::info!(session_id = session_id, "MASQUE relay session closed");
Ok(())
}
pub async fn close_session_by_client(&self, client_addr: SocketAddr) -> RelayResult<()> {
let session_id = {
let client_map = self.client_to_session.read().await;
client_map
.get(&client_addr)
.copied()
.ok_or(RelayError::SessionError {
session_id: None,
kind: SessionErrorKind::NotFound,
})?
};
self.close_session(session_id).await
}
pub async fn cleanup_expired_sessions(&self) -> usize {
let expired_ids: Vec<u64> = {
let sessions = self.sessions.read().await;
sessions
.iter()
.filter(|(_, s)| s.is_timed_out())
.map(|(id, _)| *id)
.collect()
};
let count = expired_ids.len();
for session_id in expired_ids {
if let Err(e) = self.close_session(session_id).await {
tracing::warn!(
session_id = session_id,
error = %e,
"Failed to close expired session"
);
}
}
if count > 0 {
tracing::debug!(count = count, "Cleaned up expired MASQUE sessions");
}
count
}
pub async fn session_count(&self) -> usize {
let sessions = self.sessions.read().await;
sessions.len()
}
pub async fn get_session_info(&self, session_id: u64) -> Option<SessionInfo> {
let sessions = self.sessions.read().await;
sessions.get(&session_id).map(|s| SessionInfo {
session_id: s.session_id(),
state: s.state(),
public_address: s.public_address(),
client_address: s.client_address(),
duration: s.duration(),
stats: s.stats(),
is_bridging: s.is_bridging(),
})
}
pub async fn active_session_ids(&self) -> Vec<u64> {
let sessions = self.sessions.read().await;
sessions
.iter()
.filter(|(_, s)| s.is_active())
.map(|(id, _)| *id)
.collect()
}
}
/// Point-in-time view of one relay session, as produced by
/// `MasqueRelayServer::get_session_info`.
#[derive(Debug)]
pub struct SessionInfo {
    /// Server-assigned session identifier.
    pub session_id: u64,
    /// Lifecycle state at the time of the snapshot.
    pub state: RelaySessionState,
    /// Relay address advertised to the client for this session.
    pub public_address: SocketAddr,
    /// Address of the owning client, if one was recorded.
    pub client_address: Option<SocketAddr>,
    /// How long the session had existed when the snapshot was taken.
    pub duration: Duration,
    /// Session statistics handle (an `Arc`, presumably shared with the live session,
    /// so values may keep changing after the snapshot — confirm).
    pub stats: Arc<crate::masque::RelaySessionStats>,
    /// Whether the session relays between IPv4 and IPv6.
    pub is_bridging: bool,
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    fn test_addr(port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), port)
    }

    fn client_addr(id: u8) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, id)), 12345)
    }

    #[tokio::test]
    async fn test_server_creation() {
        let config = MasqueRelayConfig::default();
        let public_addr = test_addr(9000);
        let server = MasqueRelayServer::new(config, public_addr);
        assert_eq!(server.public_address(), public_addr);
        assert_eq!(server.session_count().await, 0);
    }

    #[tokio::test]
    async fn test_connect_request_creates_session() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new(config, test_addr(9000));
        let request = ConnectUdpRequest::bind_any();
        let response = server
            .handle_connect_request(&request, client_addr(1))
            .await
            .unwrap();
        assert_eq!(response.status, 200);
        assert!(response.proxy_public_address.is_some());
        assert_eq!(server.session_count().await, 1);
    }

    #[tokio::test]
    async fn test_duplicate_client_rejected() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new(config, test_addr(9000));
        let client = client_addr(1);
        let request = ConnectUdpRequest::bind_any();
        let response1 = server
            .handle_connect_request(&request, client)
            .await
            .unwrap();
        assert_eq!(response1.status, 200);
        let response2 = server
            .handle_connect_request(&request, client)
            .await
            .unwrap();
        assert_eq!(response2.status, 409);
    }

    #[tokio::test]
    async fn test_session_limit() {
        let config = MasqueRelayConfig {
            max_sessions: 2,
            ..Default::default()
        };
        let server = MasqueRelayServer::new(config, test_addr(9000));
        let request = ConnectUdpRequest::bind_any();
        for i in 1..=2 {
            let response = server
                .handle_connect_request(&request, client_addr(i))
                .await
                .unwrap();
            assert_eq!(response.status, 200);
        }
        let response = server
            .handle_connect_request(&request, client_addr(3))
            .await
            .unwrap();
        assert_eq!(response.status, 503);
    }

    #[tokio::test]
    async fn test_target_request_accepted() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new(config, test_addr(9000));
        let request = ConnectUdpRequest::target(test_addr(8080));
        let response = server
            .handle_connect_request(&request, client_addr(1))
            .await
            .unwrap();
        assert_eq!(response.status, 200);
    }

    #[tokio::test]
    async fn test_close_session() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new(config, test_addr(9000));
        let request = ConnectUdpRequest::bind_any();
        let response = server
            .handle_connect_request(&request, client_addr(1))
            .await
            .unwrap();
        assert_eq!(response.status, 200);
        assert_eq!(server.session_count().await, 1);
        let session_ids = server.active_session_ids().await;
        assert_eq!(session_ids.len(), 1);
        server.close_session(session_ids[0]).await.unwrap();
        assert_eq!(server.session_count().await, 0);
    }

    #[tokio::test]
    async fn test_close_session_by_client() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new(config, test_addr(9000));
        let client = client_addr(1);
        let request = ConnectUdpRequest::bind_any();
        server
            .handle_connect_request(&request, client)
            .await
            .unwrap();
        assert_eq!(server.session_count().await, 1);
        server.close_session_by_client(client).await.unwrap();
        assert_eq!(server.session_count().await, 0);
    }

    /// Closing an unknown session (by ID or by client) fails and leaves the stats alone.
    #[tokio::test]
    async fn test_close_unknown_session_fails() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new(config, test_addr(9000));
        assert!(server.close_session(42).await.is_err());
        assert!(server.close_session_by_client(client_addr(1)).await.is_err());
        let stats = server.stats();
        assert_eq!(stats.current_active_sessions(), 0);
        assert_eq!(stats.sessions_terminated.load(Ordering::Relaxed), 0);
    }

    /// A second close of the same session fails and is not counted twice.
    #[tokio::test]
    async fn test_double_close_counts_once() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new(config, test_addr(9000));
        let request = ConnectUdpRequest::bind_any();
        server
            .handle_connect_request(&request, client_addr(1))
            .await
            .unwrap();
        let session_id = server.active_session_ids().await[0];
        server.close_session(session_id).await.unwrap();
        assert!(server.close_session(session_id).await.is_err());
        let stats = server.stats();
        assert_eq!(stats.current_active_sessions(), 0);
        assert_eq!(stats.sessions_terminated.load(Ordering::Relaxed), 1);
    }

    /// Closing a session frees its capacity slot and its client binding.
    #[tokio::test]
    async fn test_closing_session_releases_capacity() {
        let config = MasqueRelayConfig {
            max_sessions: 1,
            ..Default::default()
        };
        let server = MasqueRelayServer::new(config, test_addr(9000));
        let request = ConnectUdpRequest::bind_any();
        let first = server
            .handle_connect_request(&request, client_addr(1))
            .await
            .unwrap();
        assert_eq!(first.status, 200);
        let rejected = server
            .handle_connect_request(&request, client_addr(2))
            .await
            .unwrap();
        assert_eq!(rejected.status, 503);

        server.close_session_by_client(client_addr(1)).await.unwrap();
        assert!(server.get_session_for_client(client_addr(1)).await.is_none());

        let admitted = server
            .handle_connect_request(&request, client_addr(2))
            .await
            .unwrap();
        assert_eq!(admitted.status, 200);
        let stats = server.stats();
        assert_eq!(stats.current_active_sessions(), 1);
        assert_eq!(stats.sessions_terminated.load(Ordering::Relaxed), 1);
    }

    /// A client may open a new session once its previous one is closed.
    #[tokio::test]
    async fn test_client_can_reconnect_after_close() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new(config, test_addr(9000));
        let client = client_addr(1);
        let request = ConnectUdpRequest::bind_any();
        server
            .handle_connect_request(&request, client)
            .await
            .unwrap();
        server.close_session_by_client(client).await.unwrap();
        let response = server
            .handle_connect_request(&request, client)
            .await
            .unwrap();
        assert_eq!(response.status, 200);
        assert_eq!(server.session_count().await, 1);
    }

    #[tokio::test]
    async fn test_server_stats() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new(config, test_addr(9000));
        let stats = server.stats();
        assert_eq!(stats.current_active_sessions(), 0);
        let request = ConnectUdpRequest::bind_any();
        server
            .handle_connect_request(&request, client_addr(1))
            .await
            .unwrap();
        assert_eq!(stats.current_active_sessions(), 1);
        assert_eq!(stats.sessions_created.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_get_session_info() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new(config, test_addr(9000));
        let client = client_addr(1);
        let request = ConnectUdpRequest::bind_any();
        server
            .handle_connect_request(&request, client)
            .await
            .unwrap();
        let session_ids = server.active_session_ids().await;
        let info = server.get_session_info(session_ids[0]).await.unwrap();
        assert_eq!(info.client_address, Some(client));
        assert_eq!(info.state, RelaySessionState::Active);
    }

    fn ipv4_addr(port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50)), port)
    }

    fn ipv6_addr(port: u16) -> SocketAddr {
        SocketAddr::new(
            IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
            port,
        )
    }

    fn ipv4_client(id: u8) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, id)), 12345)
    }

    fn ipv6_client(id: u8) -> SocketAddr {
        SocketAddr::new(
            IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, id.into())),
            12345,
        )
    }

    #[tokio::test]
    async fn test_dual_stack_creation() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new_dual_stack(config, ipv4_addr(9000), ipv6_addr(9000));
        assert!(server.supports_dual_stack());
        assert!(server.secondary_address().is_some());
    }

    #[tokio::test]
    async fn test_single_stack_no_dual_stack() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new(config, ipv4_addr(9000));
        assert!(!server.supports_dual_stack());
        assert!(server.secondary_address().is_none());
    }

    #[tokio::test]
    async fn test_can_bridge_same_version() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new(config, ipv4_addr(9000));
        assert!(server.can_bridge(ipv4_client(1), ipv4_addr(8080)).await);
    }

    #[tokio::test]
    async fn test_can_bridge_different_version_without_dual_stack() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new(config, ipv4_addr(9000));
        assert!(!server.can_bridge(ipv4_client(1), ipv6_addr(8080)).await);
    }

    #[tokio::test]
    async fn test_can_bridge_different_version_with_dual_stack() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new_dual_stack(config, ipv4_addr(9000), ipv6_addr(9000));
        assert!(server.can_bridge(ipv4_client(1), ipv6_addr(8080)).await);
        assert!(server.can_bridge(ipv6_client(1), ipv4_addr(8080)).await);
    }

    #[tokio::test]
    async fn test_address_for_target_ipv4() {
        let config = MasqueRelayConfig::default();
        let v4 = ipv4_addr(9000);
        let v6 = ipv6_addr(9000);
        let server = MasqueRelayServer::new_dual_stack(config, v4, v6);
        let addr = server.address_for_target(&ipv4_addr(8080));
        assert!(addr.is_ipv4());
    }

    #[tokio::test]
    async fn test_address_for_target_ipv6() {
        let config = MasqueRelayConfig::default();
        let v4 = ipv4_addr(9000);
        let v6 = ipv6_addr(9000);
        let server = MasqueRelayServer::new_dual_stack(config, v4, v6);
        let addr = server.address_for_target(&ipv6_addr(8080));
        assert!(addr.is_ipv6());
    }

    #[tokio::test]
    async fn test_bridging_connect_request_rejected_without_dual_stack() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new(config, ipv4_addr(9000));
        let request = ConnectUdpRequest::target(ipv6_addr(8080));
        let response = server
            .handle_connect_request(&request, ipv4_client(1))
            .await
            .unwrap();
        assert_eq!(response.status, 501);
    }

    #[tokio::test]
    async fn test_ipv4_client_session() {
        let config = MasqueRelayConfig::default();
        let v4 = ipv4_addr(9000);
        let v6 = ipv6_addr(9000);
        let server = MasqueRelayServer::new_dual_stack(config, v4, v6);
        let request = ConnectUdpRequest::bind_any();
        let response = server
            .handle_connect_request(&request, ipv4_client(1))
            .await
            .unwrap();
        assert_eq!(response.status, 200);
        let public_addr = response.proxy_public_address.unwrap();
        assert!(public_addr.is_ipv4());
    }

    #[tokio::test]
    async fn test_ipv6_client_session() {
        let config = MasqueRelayConfig::default();
        let v4 = ipv4_addr(9000);
        let v6 = ipv6_addr(9000);
        let server = MasqueRelayServer::new_dual_stack(config, v4, v6);
        let request = ConnectUdpRequest::bind_any();
        let response = server
            .handle_connect_request(&request, ipv6_client(1))
            .await
            .unwrap();
        assert_eq!(response.status, 200);
        let public_addr = response.proxy_public_address.unwrap();
        assert!(public_addr.is_ipv6());
    }

    /// A same-family (bind-any) session must not be counted as bridged.
    #[tokio::test]
    async fn test_bridged_connection_count() {
        let config = MasqueRelayConfig::default();
        let v4 = ipv4_addr(9000);
        let v6 = ipv6_addr(9000);
        let server = MasqueRelayServer::new_dual_stack(config, v4, v6);
        assert_eq!(server.bridged_connection_count(), 0);
        let request = ConnectUdpRequest::bind_any();
        server
            .handle_connect_request(&request, ipv4_client(1))
            .await
            .unwrap();
        assert_eq!(server.bridged_connection_count(), 0);
    }

    /// A same-family (bind-any) session must not carry the bridging flag.
    #[tokio::test]
    async fn test_session_bridging_flag() {
        let config = MasqueRelayConfig::default();
        let v4 = ipv4_addr(9000);
        let v6 = ipv6_addr(9000);
        let server = MasqueRelayServer::new_dual_stack(config, v4, v6);
        let request = ConnectUdpRequest::bind_any();
        server
            .handle_connect_request(&request, ipv4_client(1))
            .await
            .unwrap();
        let session_ids = server.active_session_ids().await;
        let info = server.get_session_info(session_ids[0]).await.unwrap();
        assert!(!info.is_bridging);
    }

    /// An IPv4 client targeting an IPv6 address on a dual-stack relay is admitted,
    /// counted as bridged, and its session carries the bridging flag.
    #[tokio::test]
    async fn test_cross_family_session_is_bridged() {
        let config = MasqueRelayConfig::default();
        let server = MasqueRelayServer::new_dual_stack(config, ipv4_addr(9000), ipv6_addr(9000));
        let client = ipv4_client(1);
        let request = ConnectUdpRequest::target(ipv6_addr(8080));
        let response = server
            .handle_connect_request(&request, client)
            .await
            .unwrap();
        assert_eq!(response.status, 200);
        assert_eq!(server.bridged_connection_count(), 1);
        let info = server.get_session_for_client(client).await.unwrap();
        assert!(info.is_bridging);
        assert_eq!(info.client_address, Some(client));
    }
}