use bytes::Bytes;
use parking_lot::RwLock as ParkingRwLock;
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::net::UdpSocket;
use tokio::sync::RwLock;
use crate::VarInt;
use crate::high_level::Connection as QuicConnection;
use crate::masque::{
Capsule, CompressedDatagram, ConnectUdpRequest, ConnectUdpResponse, Datagram, RelaySession,
RelaySessionConfig, RelaySessionState, UncompressedDatagram,
};
use crate::relay::error::{RelayError, RelayResult, SessionErrorKind};
/// Configuration for a MASQUE relay server.
#[derive(Debug, Clone)]
pub struct MasqueRelayConfig {
    /// Upper bound on concurrently registered sessions; CONNECT-UDP requests beyond it
    /// are answered with status 503.
    pub max_sessions: usize,
    /// Per-session configuration, cloned into every newly created relay session.
    pub session_config: RelaySessionConfig,
    /// Period of the background expired-session reaper.
    pub cleanup_interval: Duration,
    /// Global bandwidth limit.
    /// NOTE(review): not read in this file; unit presumably bytes per second — confirm.
    pub global_bandwidth_limit: u64,
    /// Whether clients must authenticate.
    /// NOTE(review): not enforced in this file — confirm where it is consumed.
    pub require_authentication: bool,
}

impl Default for MasqueRelayConfig {
    /// 1000 sessions, default session settings, a 60 s cleanup period,
    /// a bandwidth limit of 100 * 1024 * 1024, and authentication required.
    fn default() -> Self {
        const MEBIBYTE: u64 = 1024 * 1024;
        MasqueRelayConfig {
            max_sessions: 1000,
            session_config: RelaySessionConfig::default(),
            cleanup_interval: Duration::from_secs(60),
            global_bandwidth_limit: 100 * MEBIBYTE,
            require_authentication: true,
        }
    }
}
/// Server-wide relay counters.
///
/// Every counter is an independent statistic updated with `Ordering::Relaxed`, so a
/// reader may see them slightly out of step with one another.
#[derive(Debug, Default)]
pub struct MasqueRelayStats {
    /// Sessions created over the server's lifetime.
    pub sessions_created: AtomicU64,
    /// Sessions created and not yet terminated.
    pub active_sessions: AtomicU64,
    /// Sessions terminated over the server's lifetime.
    pub sessions_terminated: AtomicU64,
    /// Sum of all amounts passed to `record_bytes`.
    pub bytes_relayed: AtomicU64,
    /// Number of `record_datagram` calls.
    pub datagrams_forwarded: AtomicU64,
    /// Number of `record_auth_failure` calls.
    pub auth_failures: AtomicU64,
    /// Number of `record_rate_limit` calls.
    pub rate_limit_rejections: AtomicU64,
}

impl MasqueRelayStats {
    /// Creates a stats block with every counter at zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Adds `amount` to `counter` with relaxed ordering.
    fn bump(counter: &AtomicU64, amount: u64) {
        counter.fetch_add(amount, Ordering::Relaxed);
    }

    /// Counts a new session in both the lifetime total and the active gauge.
    pub fn record_session_created(&self) {
        Self::bump(&self.sessions_created, 1);
        Self::bump(&self.active_sessions, 1);
    }

    /// Counts a terminated session and lowers the active gauge
    /// (the gauge wraps if called more often than `record_session_created`).
    pub fn record_session_terminated(&self) {
        Self::bump(&self.sessions_terminated, 1);
        self.active_sessions.fetch_sub(1, Ordering::Relaxed);
    }

    /// Adds `bytes` to the relayed-bytes total.
    pub fn record_bytes(&self, bytes: u64) {
        Self::bump(&self.bytes_relayed, bytes);
    }

    /// Counts one forwarded datagram.
    pub fn record_datagram(&self) {
        Self::bump(&self.datagrams_forwarded, 1);
    }

    /// Counts one authentication failure.
    pub fn record_auth_failure(&self) {
        Self::bump(&self.auth_failures, 1);
    }

    /// Counts one rate-limit rejection.
    pub fn record_rate_limit(&self) {
        Self::bump(&self.rate_limit_rejections, 1);
    }

    /// Current value of the active-session gauge.
    pub fn current_active_sessions(&self) -> u64 {
        let gauge = &self.active_sessions;
        gauge.load(Ordering::Relaxed)
    }

    /// Current relayed-bytes total.
    pub fn total_bytes_relayed(&self) -> u64 {
        let total = &self.bytes_relayed;
        total.load(Ordering::Relaxed)
    }
}
/// A datagram the relay should send from a session's UDP socket to `target`.
#[derive(Debug, Clone)]
pub struct OutboundDatagram {
/// Destination resolved from the client's datagram.
pub target: SocketAddr,
/// Bytes to send to `target`.
pub payload: Bytes,
/// Session the datagram belongs to.
pub session_id: u64,
}
/// Outcome of processing a datagram received from a relay client.
#[derive(Debug)]
pub enum DatagramResult {
/// The payload should be sent to the contained target.
Forward(OutboundDatagram),
/// NOTE(review): not produced in this file; presumably a datagram consumed by the relay itself — confirm.
Internal,
/// No session is registered for the sender (or it disappeared during lookup).
SessionNotFound,
/// The datagram could not be processed (e.g. its context ID is unknown).
Error(RelayError),
}
/// MASQUE CONNECT-UDP relay server: owns the session table, the client→session index and
/// the relay's advertised public address(es).
#[derive(Debug)]
pub struct MasqueRelayServer {
config: MasqueRelayConfig,
/// Addresses given at construction; `reconcile_public_addresses` falls back to these per IP family.
default_public_address: SocketAddr,
default_secondary_address: Option<SocketAddr>,
/// Currently advertised addresses. Synchronous locks; they are only taken inside non-async code paths.
public_address: ParkingRwLock<SocketAddr>,
secondary_address: ParkingRwLock<Option<SocketAddr>>,
/// Registered sessions keyed by session ID. `close_session` locks this before `client_to_session`.
sessions: RwLock<HashMap<u64, RelaySession>>,
/// Reverse index: client address → its session ID (a client may hold only one session).
client_to_session: RwLock<HashMap<SocketAddr, u64>>,
/// Session ID allocator; the first ID handed out is 1.
next_session_id: AtomicU64,
stats: Arc<MasqueRelayStats>,
started_at: Instant,
/// Number of sessions created that needed IPv4/IPv6 bridging.
bridged_connections: AtomicU64,
}
impl MasqueRelayServer {
pub fn new(config: MasqueRelayConfig, public_address: SocketAddr) -> Self {
Self {
config,
default_public_address: public_address,
default_secondary_address: None,
public_address: ParkingRwLock::new(public_address),
secondary_address: ParkingRwLock::new(None),
sessions: RwLock::new(HashMap::new()),
client_to_session: RwLock::new(HashMap::new()),
next_session_id: AtomicU64::new(1),
stats: Arc::new(MasqueRelayStats::new()),
started_at: Instant::now(),
bridged_connections: AtomicU64::new(0),
}
}
pub fn new_dual_stack(
config: MasqueRelayConfig,
ipv4_address: SocketAddr,
ipv6_address: SocketAddr,
) -> Self {
let (primary, secondary) = if ipv4_address.is_ipv4() {
(ipv4_address, ipv6_address)
} else {
(ipv6_address, ipv4_address)
};
Self {
config,
default_public_address: primary,
default_secondary_address: Some(secondary),
public_address: ParkingRwLock::new(primary),
secondary_address: ParkingRwLock::new(Some(secondary)),
sessions: RwLock::new(HashMap::new()),
client_to_session: RwLock::new(HashMap::new()),
next_session_id: AtomicU64::new(1),
stats: Arc::new(MasqueRelayStats::new()),
started_at: Instant::now(),
bridged_connections: AtomicU64::new(0),
}
}
/// Returns the advertised `(primary, secondary)` addresses as one consistent snapshot.
///
/// Both read guards are held together, in the same order writers take the write guards
/// (primary, then secondary). A concurrent `set_public_address` or
/// `reconcile_public_addresses` therefore can never be observed half-applied, which two
/// independent lock acquisitions could not guarantee.
fn address_pair(&self) -> (SocketAddr, Option<SocketAddr>) {
    let primary = self.public_address.read();
    let secondary = self.secondary_address.read();
    (*primary, *secondary)
}

/// True when the relay currently advertises one IPv4 and one IPv6 address.
pub fn supports_dual_stack(&self) -> bool {
    match self.address_pair() {
        (primary, Some(secondary)) => primary.is_ipv4() != secondary.is_ipv4(),
        (_, None) => false,
    }
}

/// Whether traffic from `source` to `target` can be relayed.
///
/// Same-family traffic always can; cross-family traffic needs a dual-stack relay.
pub async fn can_bridge(&self, source: SocketAddr, target: SocketAddr) -> bool {
    source.is_ipv4() == target.is_ipv4() || self.supports_dual_stack()
}

/// Picks the advertised address to use towards `target`.
///
/// Returns the secondary address when one exists and the primary's family differs from
/// `target`'s; otherwise returns the primary.
pub fn address_for_target(&self, target: &SocketAddr) -> SocketAddr {
    match self.address_pair() {
        (primary, Some(secondary)) if primary.is_ipv4() != target.is_ipv4() => secondary,
        (primary, _) => primary,
    }
}

/// Currently advertised secondary address, if any.
pub fn secondary_address(&self) -> Option<SocketAddr> {
    *self.secondary_address.read()
}

/// Construction-time default address of the requested family, if one was supplied.
fn default_address_for_family(&self, want_ipv4: bool) -> Option<SocketAddr> {
    if self.default_public_address.is_ipv4() == want_ipv4 {
        Some(self.default_public_address)
    } else {
        self.default_secondary_address
            .filter(|addr| addr.is_ipv4() == want_ipv4)
    }
}
/// Number of sessions created so far that required IPv4/IPv6 bridging.
pub fn bridged_connection_count(&self) -> u64 {
    let counter = &self.bridged_connections;
    counter.load(Ordering::Relaxed)
}

/// Counts one more bridged session.
fn record_bridged_connection(&self) {
    let counter = &self.bridged_connections;
    counter.fetch_add(1, Ordering::Relaxed);
}

/// Shared handle to the server-wide counters.
pub fn stats(&self) -> Arc<MasqueRelayStats> {
    let shared = &self.stats;
    Arc::clone(shared)
}

/// Time elapsed since the server was constructed.
pub fn uptime(&self) -> Duration {
    let started = self.started_at;
    started.elapsed()
}

/// Currently advertised primary address.
pub fn public_address(&self) -> SocketAddr {
    let guard = self.public_address.read();
    *guard
}
/// Replaces the advertised address that shares `addr`'s IP family.
///
/// Nothing changes if `addr` is already advertised, as primary or secondary. Otherwise the
/// primary is replaced when it has `addr`'s family. Failing that, an existing secondary of
/// `addr`'s family is replaced. In every other case the primary is overwritten.
pub fn set_public_address(&self, addr: SocketAddr) {
    // Lock order: primary first, then secondary (same as reconcile_public_addresses).
    let mut primary = self.public_address.write();
    let mut secondary = self.secondary_address.write();
    if *primary == addr || *secondary == Some(addr) {
        return;
    }
    let wants_ipv4 = addr.is_ipv4();
    let current_secondary: Option<SocketAddr> = *secondary;
    let replaced = match current_secondary {
        Some(current)
            if primary.is_ipv4() != wants_ipv4 && current.is_ipv4() == wants_ipv4 =>
        {
            *secondary = Some(addr);
            current
        }
        _ => std::mem::replace(&mut *primary, addr),
    };
    tracing::info!(
        old = %replaced,
        new = %addr,
        "Relay server public address updated"
    );
}
/// Sets the advertised addresses from freshly observed per-family addresses.
///
/// A missing family falls back to the construction-time default of that family. When an
/// IPv4 address is available it becomes primary, with IPv6 (if any) as secondary. With
/// only IPv6 available, IPv6 becomes primary and there is no secondary. Logs only when the
/// stored pair actually changes.
pub fn reconcile_public_addresses(&self, ipv4: Option<SocketAddr>, ipv6: Option<SocketAddr>) {
    let resolved = (
        ipv4.or_else(|| self.default_address_for_family(true)),
        ipv6.or_else(|| self.default_address_for_family(false)),
    );
    let (new_primary, new_secondary) = match resolved {
        (Some(v4), maybe_v6) => (v4, maybe_v6),
        (None, Some(v6)) => (v6, None),
        (None, None) => (self.default_public_address, self.default_secondary_address),
    };
    let mut primary = self.public_address.write();
    let mut secondary = self.secondary_address.write();
    let unchanged = *primary == new_primary && *secondary == new_secondary;
    if unchanged {
        return;
    }
    let old_primary = std::mem::replace(&mut *primary, new_primary);
    let old_secondary = std::mem::replace(&mut *secondary, new_secondary);
    tracing::info!(
        old_primary = %old_primary,
        old_secondary = ?old_secondary,
        new_primary = %new_primary,
        new_secondary = ?new_secondary,
        "Relay server public addresses reconciled"
    );
}
/// Handles a CONNECT-UDP request from `client_addr`. On success it registers a new, active
/// relay session backed by a freshly bound UDP socket.
///
/// Protocol-level rejections are returned as `Ok` error responses:
/// * 503 when the server is at `max_sessions`,
/// * 409 when `client_addr` already has a session,
/// * 501 when the target needs IPv4/IPv6 bridging but the relay is not dual-stack.
///
/// On success the response carries the advertised `public_ip:bound_port` of the session.
///
/// # Errors
///
/// Returns an error if the UDP socket cannot be bound, its local address cannot be read,
/// or the new session cannot be activated.
pub async fn handle_connect_request(
    &self,
    request: &ConnectUdpRequest,
    client_addr: SocketAddr,
) -> RelayResult<ConnectUdpResponse> {
    // Cheap early rejections, so that an obviously doomed request does not bind a socket.
    // Both checks are repeated atomically right before insertion (see below).
    if self.stats.current_active_sessions() >= self.config.max_sessions as u64 {
        return Ok(ConnectUdpResponse::error(
            503,
            "Server at capacity".to_string(),
        ));
    }
    let already_registered = self
        .client_to_session
        .read()
        .await
        .contains_key(&client_addr);
    if already_registered {
        return Ok(ConnectUdpResponse::error(
            409,
            "Session already exists for this client".to_string(),
        ));
    }

    let requires_bridging = request
        .target_address()
        .is_some_and(|target| target.is_ipv4() != client_addr.is_ipv4());
    if requires_bridging && !self.supports_dual_stack() {
        return Ok(ConnectUdpResponse::error(
            501,
            "IPv4/IPv6 bridging not supported by this relay".to_string(),
        ));
    }

    // Read both advertised addresses while holding both guards, so a concurrent address
    // update cannot hand us a half-applied pair.
    let (public_address, secondary_address) = {
        let primary = self.public_address.read();
        let secondary = self.secondary_address.read();
        (*primary, *secondary)
    };
    // Prefer an advertised address of the client's family; otherwise fall back to the
    // secondary (or, without one, the primary).
    let public_ip = if public_address.is_ipv4() == client_addr.is_ipv4() {
        public_address.ip()
    } else {
        secondary_address.unwrap_or(public_address).ip()
    };

    // Bind in the family of the advertised IP. The advertised `ip:port` must name this
    // socket, and a socket of the other family can never receive traffic sent to it.
    // NOTE(review): for bridged sessions this still follows the client's family, not the
    // target's — confirm how traffic to a target of the other family is expected to leave.
    let bind_ip = if public_ip.is_ipv4() {
        IpAddr::V4(Ipv4Addr::UNSPECIFIED)
    } else {
        IpAddr::V6(Ipv6Addr::UNSPECIFIED)
    };
    let udp_socket = UdpSocket::bind(SocketAddr::new(bind_ip, 0))
        .await
        .map_err(|e| RelayError::SessionError {
            session_id: None,
            kind: SessionErrorKind::InvalidState {
                current_state: format!("UDP bind failed: {}", e),
                expected_state: "bound".into(),
            },
        })?;
    let bound_port = udp_socket
        .local_addr()
        .map_err(|e| RelayError::SessionError {
            session_id: None,
            kind: SessionErrorKind::InvalidState {
                current_state: format!("Failed to get bound address: {}", e),
                expected_state: "address available".into(),
            },
        })?
        .port();
    let advertised_address = SocketAddr::new(public_ip, bound_port);

    let session_id = {
        // Take both write locks in the same order as `close_session` (sessions, then the
        // client index) and re-validate. Another request for the same client, or one that
        // took the last slot, may have been inserted while we were binding the socket.
        // Without this re-check, two racing requests from one client would both be
        // inserted and the first session would be orphaned outside the client index.
        let mut sessions = self.sessions.write().await;
        let mut client_map = self.client_to_session.write().await;
        if client_map.contains_key(&client_addr) {
            return Ok(ConnectUdpResponse::error(
                409,
                "Session already exists for this client".to_string(),
            ));
        }
        if sessions.len() >= self.config.max_sessions {
            return Ok(ConnectUdpResponse::error(
                503,
                "Server at capacity".to_string(),
            ));
        }
        let session_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
        let mut session = RelaySession::new(
            session_id,
            self.config.session_config.clone(),
            advertised_address,
        );
        session.set_client_address(client_addr);
        session.set_udp_socket(Arc::new(udp_socket));
        if requires_bridging {
            session.set_bridging(true);
        }
        session.activate()?;
        sessions.insert(session_id, session);
        client_map.insert(client_addr, session_id);
        session_id
    };

    self.stats.record_session_created();
    if requires_bridging {
        self.record_bridged_connection();
    }
    tracing::info!(
        session_id = session_id,
        client = %client_addr,
        public_addr = %advertised_address,
        bound_port = bound_port,
        bridging = requires_bridging,
        dual_stack = self.supports_dual_stack(),
        "MASQUE relay session created with bound UDP socket"
    );
    Ok(ConnectUdpResponse::success(Some(advertised_address)))
}
/// Snapshot of the session registered for `client_addr`, if any.
pub async fn get_session_for_client(&self, client_addr: SocketAddr) -> Option<SessionInfo> {
    let session_id = {
        let index = self.client_to_session.read().await;
        *index.get(&client_addr)?
    };
    self.get_session_info(session_id).await
}

/// Closes the session of `client_addr` on a best-effort basis; any error is ignored.
pub async fn terminate_session_for_client(&self, client_addr: SocketAddr) {
    let _ignored = self.close_session_by_client(client_addr).await;
}
/// Routes `capsule` to the session registered for `client_addr` and returns the
/// session's reply capsule, if it produces one.
///
/// # Errors
///
/// `SessionError { kind: NotFound }` if the client has no session, or if the session
/// vanished between the index lookup and the table lookup. Otherwise, whatever the
/// session's own capsule handling returns.
pub async fn handle_capsule(
    &self,
    client_addr: SocketAddr,
    capsule: Capsule,
) -> RelayResult<Option<Capsule>> {
    let indexed = self.client_to_session.read().await.get(&client_addr).copied();
    let session_id = indexed.ok_or_else(|| RelayError::SessionError {
        session_id: None,
        kind: SessionErrorKind::NotFound,
    })?;
    let mut table = self.sessions.write().await;
    let session = table
        .get_mut(&session_id)
        .ok_or_else(|| RelayError::SessionError {
            session_id: Some(session_id as u32),
            kind: SessionErrorKind::NotFound,
        })?;
    session.handle_capsule(capsule)
}
/// Resolves a datagram received from `client_addr` to its target and, on success,
/// counts `payload.len()` bytes and one datagram in the server stats.
///
/// Returns `SessionNotFound` when the client has no (longer a) session, and `Error` when
/// the session cannot resolve the datagram's context ID.
pub async fn handle_client_datagram(
    &self,
    client_addr: SocketAddr,
    datagram: Datagram,
    payload: Bytes,
) -> DatagramResult {
    let indexed = self.client_to_session.read().await.get(&client_addr).copied();
    let Some(session_id) = indexed else {
        return DatagramResult::SessionNotFound;
    };
    let resolved = {
        let table = self.sessions.read().await;
        match table.get(&session_id) {
            None => return DatagramResult::SessionNotFound,
            Some(session) => session.resolve_target(&datagram),
        }
    };
    let Some(target) = resolved else {
        return DatagramResult::Error(RelayError::ProtocolError {
            frame_type: 0x00,
            reason: "Unknown context ID".into(),
        });
    };
    self.stats.record_bytes(payload.len() as u64);
    self.stats.record_datagram();
    DatagramResult::Forward(OutboundDatagram {
        target,
        payload,
        session_id,
    })
}
pub async fn handle_target_datagram(
&self,
session_id: u64,
source: SocketAddr,
payload: Bytes,
) -> RelayResult<(SocketAddr, Bytes)> {
let mut sessions = self.sessions.write().await;
let session = sessions
.get_mut(&session_id)
.ok_or(RelayError::SessionError {
session_id: Some(session_id as u32),
kind: SessionErrorKind::NotFound,
})?;
let client_addr = session.client_address().ok_or(RelayError::SessionError {
session_id: Some(session_id as u32),
kind: SessionErrorKind::InvalidState {
current_state: "no client address".into(),
expected_state: "client address set".into(),
},
})?;
let ctx_id = session.context_for_target(source)?;
let datagram = crate::masque::CompressedDatagram::new(ctx_id, payload.clone());
let encoded = datagram.encode();
self.stats.record_bytes(encoded.len() as u64);
self.stats.record_datagram();
Ok((client_addr, encoded))
}
/// Runs the QUIC-datagram forwarding loop for `session_id` until either direction stops,
/// then closes the session.
///
/// * UDP → QUIC: every packet received on the session's UDP socket is wrapped in an
///   `UncompressedDatagram` (context ID 0, carrying the source address) and sent as a QUIC
///   datagram. Datagrams rejected as too large are skipped; any other send error ends the loop.
/// * QUIC → UDP: every QUIC datagram is decoded as an `UncompressedDatagram` and its payload
///   sent to the embedded target. If that decode fails, it is decoded as a
///   `CompressedDatagram` and resolved through the session registered for the connection's
///   remote address.
///
/// Server-wide stats are updated only for datagrams that were actually sent.
/// A third branch logs the *server-wide* traffic counters every 5 seconds; it never ends
/// on its own.
pub async fn run_forwarding_loop(
    self: &Arc<Self>,
    session_id: u64,
    connection: QuicConnection,
) {
    let udp_socket = {
        let sessions = self.sessions.read().await;
        match sessions.get(&session_id) {
            Some(s) => s.udp_socket().cloned(),
            None => {
                tracing::warn!(session_id, "Cannot start forwarding: session not found");
                return;
            }
        }
    };
    let Some(socket) = udp_socket else {
        tracing::warn!(session_id, "Cannot start forwarding: no UDP socket bound");
        return;
    };
    tracing::info!(
        session_id,
        bound_addr = %socket.local_addr().map(|a| a.to_string()).unwrap_or_default(),
        "Starting relay forwarding loop"
    );
    let stats = self.stats();
    tokio::select! {
        // UDP → QUIC
        _ = async {
            let mut buf = vec![0u8; 65536];
            loop {
                let (len, source) = match socket.recv_from(&mut buf).await {
                    Ok(received) => received,
                    Err(e) => {
                        tracing::debug!(
                            session_id,
                            error = %e,
                            "UDP socket recv error, stopping UDP→QUIC"
                        );
                        break;
                    }
                };
                tracing::trace!(
                    session_id,
                    source = %source,
                    len,
                    "Relay: received UDP from target"
                );
                let payload = Bytes::copy_from_slice(&buf[..len]);
                let encoded =
                    UncompressedDatagram::new(VarInt::from_u32(0), source, payload).encode();
                let encoded_len = encoded.len() as u64;
                match connection.send_datagram(encoded) {
                    Ok(_) => {
                        // Count only datagrams actually handed to QUIC.
                        stats.record_bytes(encoded_len);
                        stats.record_datagram();
                    }
                    Err(e) => {
                        // NOTE(review): matching on the error text is brittle; prefer the
                        // error type's "too large" variant if it is exposed.
                        let err_str = e.to_string();
                        if err_str.contains("too large") || err_str.contains("TooLarge") {
                            tracing::warn!(
                                target: "ant_quic::silent_drop",
                                kind = "relay_oversized_datagram",
                                session_id,
                                len,
                                "skipping oversized datagram for relay"
                            );
                        } else {
                            tracing::warn!(
                                target: "ant_quic::silent_drop",
                                kind = "relay_send_datagram_fatal",
                                session_id,
                                error = %e,
                                "fatal datagram send error, stopping UDP→QUIC"
                            );
                            break;
                        }
                    }
                }
            }
        } => {},
        // QUIC → UDP
        _ = async {
            loop {
                let data = match connection.read_datagram().await {
                    Ok(data) => data,
                    Err(e) => {
                        tracing::debug!(
                            session_id,
                            error = %e,
                            "QUIC connection closed, stopping QUIC→UDP"
                        );
                        break;
                    }
                };
                let mut cursor = data.clone();
                match UncompressedDatagram::decode(&mut cursor) {
                    Ok(datagram) => {
                        let target = datagram.target;
                        let payload = &datagram.payload;
                        tracing::trace!(
                            session_id,
                            target = %target,
                            len = payload.len(),
                            "Relay: forwarding to target via UDP"
                        );
                        match socket.send_to(payload, target).await {
                            Ok(_) => {
                                stats.record_bytes(payload.len() as u64);
                                stats.record_datagram();
                            }
                            Err(e) => {
                                tracing::warn!(
                                    session_id,
                                    target = %target,
                                    error = %e,
                                    "Failed to send UDP to target"
                                );
                            }
                        }
                    }
                    Err(_) => {
                        let mut cursor2 = data.clone();
                        if let Ok(compressed) = CompressedDatagram::decode(&mut cursor2) {
                            let client_addr = connection.remote_address();
                            let datagram = Datagram::Compressed(compressed);
                            let payload_clone = datagram.payload().clone();
                            match self
                                .handle_client_datagram(client_addr, datagram, payload_clone)
                                .await
                            {
                                DatagramResult::Forward(outbound) => {
                                    // `handle_client_datagram` has already counted this
                                    // datagram; counting it here again would double it.
                                    if let Err(e) =
                                        socket.send_to(&outbound.payload, outbound.target).await
                                    {
                                        tracing::warn!(
                                            session_id,
                                            target = %outbound.target,
                                            error = %e,
                                            "Failed to send UDP to target (compressed)"
                                        );
                                    }
                                }
                                DatagramResult::Error(e) => {
                                    tracing::debug!(
                                        session_id,
                                        error = %e,
                                        "Failed to process compressed datagram"
                                    );
                                }
                                DatagramResult::SessionNotFound => {
                                    tracing::debug!(
                                        session_id,
                                        client = %client_addr,
                                        "No relay session for compressed datagram sender, dropping"
                                    );
                                }
                                DatagramResult::Internal => {}
                            }
                        } else {
                            tracing::debug!(
                                session_id,
                                len = data.len(),
                                "Failed to decode relay datagram, skipping"
                            );
                        }
                    }
                }
            }
        } => {},
        // Periodic traffic report (server-wide counters, not per-session).
        _ = async {
            let mut tick = tokio::time::interval(Duration::from_secs(5));
            // The first tick completes immediately; consume it so the first report is after 5 s.
            tick.tick().await;
            loop {
                tick.tick().await;
                tracing::warn!(
                    target: "ant_quic::relay_traffic",
                    session_id,
                    bytes_forwarded = stats.total_bytes_relayed(),
                    datagrams = stats.datagrams_forwarded.load(Ordering::Relaxed),
                    "relay session traffic"
                );
            }
        } => {},
    }
    tracing::info!(session_id, "Relay forwarding loop ended");
    if let Err(e) = self.close_session(session_id).await {
        tracing::debug!(session_id, error = %e, "Error closing session after forwarding ended");
    }
}
/// Runs the stream-based forwarding loop for `session_id` until either direction stops,
/// then closes the session.
///
/// Each frame on the QUIC stream is a 4-byte big-endian length followed by an encoded
/// `UncompressedDatagram`.
/// * UDP → stream: each packet received on the session's UDP socket is wrapped (context ID 0,
///   source address) and written as one frame; any write error ends the loop.
/// * stream → UDP: each frame is decoded and its payload sent to the embedded target. A
///   declared length above 65536 ends the loop; undecodable frames are skipped.
///
/// Server-wide stats are updated only for data that was actually written or sent.
pub async fn run_stream_forwarding_loop(
    self: &Arc<Self>,
    session_id: u64,
    mut send_stream: crate::high_level::SendStream,
    mut recv_stream: crate::high_level::RecvStream,
) {
    let udp_socket = {
        let sessions = self.sessions.read().await;
        match sessions.get(&session_id) {
            Some(s) => s.udp_socket().cloned(),
            None => {
                tracing::warn!(
                    session_id,
                    "Cannot start stream forwarding: session not found"
                );
                return;
            }
        }
    };
    let Some(socket) = udp_socket else {
        tracing::warn!(session_id, "Cannot start stream forwarding: no UDP socket");
        return;
    };
    tracing::info!(
        session_id,
        bound_addr = %socket.local_addr().map(|a| a.to_string()).unwrap_or_default(),
        "Starting stream-based relay forwarding loop"
    );
    let stats = self.stats();
    tokio::select! {
        // UDP → stream
        _ = async {
            let mut buf = vec![0u8; 65536];
            loop {
                match socket.recv_from(&mut buf).await {
                    Ok((len, source)) => {
                        let payload = Bytes::copy_from_slice(&buf[..len]);
                        tracing::trace!(
                            session_id, source = %source, len,
                            "Stream relay: received UDP from target"
                        );
                        let datagram = UncompressedDatagram::new(
                            VarInt::from_u32(0), source, payload,
                        );
                        let encoded = datagram.encode();
                        // Fits in u32: the payload is at most 65536 bytes plus a small header.
                        let frame_len = encoded.len() as u32;
                        if let Err(e) = send_stream.write_all(&frame_len.to_be_bytes()).await {
                            tracing::debug!(session_id, error = %e, "Stream write error (length)");
                            break;
                        }
                        if let Err(e) = send_stream.write_all(&encoded).await {
                            tracing::debug!(session_id, error = %e, "Stream write error (data)");
                            break;
                        }
                        stats.record_bytes(encoded.len() as u64);
                        stats.record_datagram();
                    }
                    Err(e) => {
                        tracing::debug!(session_id, error = %e, "UDP recv error");
                        break;
                    }
                }
            }
        } => {},
        // stream → UDP
        _ = async {
            loop {
                let mut len_buf = [0u8; 4];
                if let Err(e) = recv_stream.read_exact(&mut len_buf).await {
                    tracing::debug!(session_id, error = %e, "Stream read error (length)");
                    break;
                }
                let frame_len = u32::from_be_bytes(len_buf) as usize;
                // The frame body cannot be skipped without reading it, so an oversized
                // length ends the loop instead of allocating an attacker-chosen buffer.
                if frame_len > 65536 {
                    tracing::warn!(session_id, frame_len, "Oversized stream frame, dropping");
                    break;
                }
                let mut frame_buf = vec![0u8; frame_len];
                if let Err(e) = recv_stream.read_exact(&mut frame_buf).await {
                    tracing::debug!(session_id, error = %e, "Stream read error (data)");
                    break;
                }
                let mut cursor = Bytes::from(frame_buf);
                match UncompressedDatagram::decode(&mut cursor) {
                    Ok(datagram) => {
                        tracing::trace!(
                            session_id, target = %datagram.target,
                            len = datagram.payload.len(),
                            "Stream relay: forwarding to target via UDP"
                        );
                        match socket.send_to(&datagram.payload, datagram.target).await {
                            Ok(_) => {
                                // Count only payloads that actually left the UDP socket.
                                stats.record_bytes(datagram.payload.len() as u64);
                                stats.record_datagram();
                            }
                            Err(e) => {
                                tracing::warn!(
                                    session_id, target = %datagram.target, error = %e,
                                    "Failed to send UDP to target"
                                );
                            }
                        }
                    }
                    Err(_) => {
                        tracing::debug!(session_id, "Failed to decode stream frame");
                    }
                }
            }
        } => {},
    }
    tracing::info!(session_id, "Stream-based relay forwarding loop ended");
    if let Err(e) = self.close_session(session_id).await {
        tracing::debug!(session_id, error = %e, "Error closing session");
    }
}
/// Removes and closes session `session_id`, unregisters its client and counts the
/// termination.
///
/// # Errors
///
/// `SessionError { kind: NotFound }` if no session with that ID is registered.
pub async fn close_session(&self, session_id: u64) -> RelayResult<()> {
    // Lock order: sessions first, then the client index.
    let mut sessions = self.sessions.write().await;
    let mut client_map = self.client_to_session.write().await;
    let mut session = sessions
        .remove(&session_id)
        .ok_or(RelayError::SessionError {
            session_id: Some(session_id as u32),
            kind: SessionErrorKind::NotFound,
        })?;
    let client_addr = session.client_address();
    session.close();
    // Only drop the index entry if it still points at *this* session. Closing a stale or
    // orphaned session must never unregister a newer session of the same client.
    if let Some(addr) = client_addr.filter(|addr| client_map.get(addr) == Some(&session_id)) {
        client_map.remove(&addr);
    }
    self.stats.record_session_terminated();
    tracing::info!(session_id = session_id, "MASQUE relay session closed");
    Ok(())
}
/// Closes the session registered for `client_addr`.
///
/// # Errors
///
/// `SessionError { kind: NotFound }` if the client has no session; otherwise any
/// error from `close_session`.
pub async fn close_session_by_client(&self, client_addr: SocketAddr) -> RelayResult<()> {
    let indexed = self.client_to_session.read().await.get(&client_addr).copied();
    let session_id = indexed.ok_or(RelayError::SessionError {
        session_id: None,
        kind: SessionErrorKind::NotFound,
    })?;
    self.close_session(session_id).await
}
pub async fn cleanup_expired_sessions(&self) -> usize {
let expired_ids: Vec<u64> = {
let sessions = self.sessions.read().await;
sessions
.iter()
.filter(|(_, s)| s.is_timed_out())
.map(|(id, _)| *id)
.collect()
};
let count = expired_ids.len();
for session_id in expired_ids {
if let Err(e) = self.close_session(session_id).await {
tracing::warn!(
session_id = session_id,
error = %e,
"Failed to close expired session"
);
}
}
if count > 0 {
tracing::debug!(count = count, "Cleaned up expired MASQUE sessions");
}
count
}
/// Spawns a background task that reaps expired sessions every `cleanup_interval`.
///
/// The task holds only a weak reference to the server, so it exits on the first tick
/// after the last strong `Arc` is dropped. Missed ticks are skipped rather than replayed.
pub fn spawn_cleanup_task(server: &Arc<Self>) -> tokio::task::JoinHandle<()> {
    let server_ref = Arc::downgrade(server);
    let period = server.config.cleanup_interval;
    tokio::spawn(async move {
        let mut ticker = tokio::time::interval(period);
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        loop {
            ticker.tick().await;
            // Upgrade only for the duration of one pass.
            let Some(live) = server_ref.upgrade() else {
                break;
            };
            let cleaned = live.cleanup_expired_sessions().await;
            if cleaned == 0 {
                continue;
            }
            let remaining = live.session_count().await;
            tracing::info!(
                cleaned = cleaned,
                remaining = remaining,
                "Reaped expired MASQUE relay sessions"
            );
        }
    })
}
/// Number of sessions currently in the session table (in any state).
pub async fn session_count(&self) -> usize {
    let table = self.sessions.read().await;
    table.len()
}

/// Snapshot of session `session_id`, or `None` if it is not registered.
pub async fn get_session_info(&self, session_id: u64) -> Option<SessionInfo> {
    let table = self.sessions.read().await;
    let session = table.get(&session_id)?;
    Some(SessionInfo {
        session_id: session.session_id(),
        state: session.state(),
        public_address: session.public_address(),
        client_address: session.client_address(),
        duration: session.duration(),
        stats: session.stats(),
        is_bridging: session.is_bridging(),
    })
}

/// IDs of all sessions whose `is_active()` reports true, in map iteration order.
pub async fn active_session_ids(&self) -> Vec<u64> {
    let table = self.sessions.read().await;
    table
        .iter()
        .filter_map(|(id, session)| session.is_active().then_some(*id))
        .collect()
}
}
/// Point-in-time snapshot of a relay session, as returned by `get_session_info`.
#[derive(Debug)]
pub struct SessionInfo {
/// Session identifier.
pub session_id: u64,
/// Session state at snapshot time.
pub state: RelaySessionState,
/// Relay address advertised to the client for this session.
pub public_address: SocketAddr,
/// Client address, if the session has one recorded.
pub client_address: Option<SocketAddr>,
/// Value of the session's `duration()` at snapshot time.
pub duration: Duration,
/// Shared per-session statistics (an `Arc`, so it keeps updating after the snapshot).
pub stats: Arc<crate::masque::RelaySessionStats>,
/// Whether the session was flagged as IPv4/IPv6 bridging.
pub is_bridging: bool,
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
fn test_addr(port: u16) -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), port)
}
fn client_addr(id: u8) -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, id)), 12345)
}
#[tokio::test]
async fn test_server_creation() {
let config = MasqueRelayConfig::default();
let public_addr = test_addr(9000);
let server = MasqueRelayServer::new(config, public_addr);
assert_eq!(server.public_address(), public_addr);
assert_eq!(server.session_count().await, 0);
}
#[tokio::test]
async fn test_set_public_address_updates_future_advertisements() {
let config = MasqueRelayConfig::default();
let server = MasqueRelayServer::new(config, test_addr(9000));
let refreshed_addr: SocketAddr = "198.51.100.77:9000".parse().expect("valid addr");
server.set_public_address(refreshed_addr);
let response = server
.handle_connect_request(&ConnectUdpRequest::bind_any(), client_addr(9))
.await
.expect("connect request should succeed");
assert_eq!(
response
.proxy_public_address
.expect("proxy public address should be present")
.ip(),
refreshed_addr.ip()
);
}
#[tokio::test]
async fn test_reconcile_public_addresses_falls_back_to_defaults_per_family() {
let config = MasqueRelayConfig::default();
let default_ipv4: SocketAddr = "0.0.0.0:9000".parse().expect("valid addr");
let default_ipv6: SocketAddr = "[::]:9000".parse().expect("valid addr");
let server = MasqueRelayServer::new_dual_stack(config, default_ipv4, default_ipv6);
let observed_ipv4: SocketAddr = "198.51.100.77:9000".parse().expect("valid addr");
server.reconcile_public_addresses(Some(observed_ipv4), None);
assert_eq!(server.public_address(), observed_ipv4);
assert_eq!(server.secondary_address(), Some(default_ipv6));
let observed_ipv6: SocketAddr = "[2001:db8::77]:9000".parse().expect("valid addr");
server.reconcile_public_addresses(None, Some(observed_ipv6));
assert_eq!(server.public_address(), default_ipv4);
assert_eq!(server.secondary_address(), Some(observed_ipv6));
server.reconcile_public_addresses(None, None);
assert_eq!(server.public_address(), default_ipv4);
assert_eq!(server.secondary_address(), Some(default_ipv6));
}
#[tokio::test]
async fn test_connect_request_creates_session() {
let config = MasqueRelayConfig::default();
let server = MasqueRelayServer::new(config, test_addr(9000));
let request = ConnectUdpRequest::bind_any();
let response = server
.handle_connect_request(&request, client_addr(1))
.await
.unwrap();
assert_eq!(response.status, 200);
assert!(response.proxy_public_address.is_some());
assert_eq!(server.session_count().await, 1);
}
#[tokio::test]
async fn test_duplicate_client_rejected() {
let config = MasqueRelayConfig::default();
let server = MasqueRelayServer::new(config, test_addr(9000));
let client = client_addr(1);
let request = ConnectUdpRequest::bind_any();
let response1 = server
.handle_connect_request(&request, client)
.await
.unwrap();
assert_eq!(response1.status, 200);
let response2 = server
.handle_connect_request(&request, client)
.await
.unwrap();
assert_eq!(response2.status, 409);
}
#[tokio::test]
async fn test_session_limit() {
let config = MasqueRelayConfig {
max_sessions: 2,
..Default::default()
};
let server = MasqueRelayServer::new(config, test_addr(9000));
let request = ConnectUdpRequest::bind_any();
for i in 1..=2 {
let response = server
.handle_connect_request(&request, client_addr(i))
.await
.unwrap();
assert_eq!(response.status, 200);
}
let response = server
.handle_connect_request(&request, client_addr(3))
.await
.unwrap();
assert_eq!(response.status, 503);
}
#[tokio::test]
async fn test_cleanup_task_stops_when_server_dropped() {
let config = MasqueRelayConfig {
cleanup_interval: Duration::from_millis(50),
..Default::default()
};
let server = Arc::new(MasqueRelayServer::new(config, test_addr(9000)));
let handle = MasqueRelayServer::spawn_cleanup_task(&server);
tokio::time::sleep(Duration::from_millis(80)).await;
drop(server);
tokio::time::timeout(Duration::from_secs(1), handle)
.await
.expect("cleanup task should stop after server drop")
.expect("cleanup task should complete cleanly");
}
#[tokio::test]
async fn test_cleanup_task_reaps_expired_sessions() {
let config = MasqueRelayConfig {
cleanup_interval: Duration::from_millis(50),
session_config: RelaySessionConfig {
session_timeout: Duration::from_millis(10),
..Default::default()
},
..Default::default()
};
let server = Arc::new(MasqueRelayServer::new(config, test_addr(9000)));
let _handle = MasqueRelayServer::spawn_cleanup_task(&server);
let request = ConnectUdpRequest::bind_any();
let response = server
.handle_connect_request(&request, client_addr(1))
.await
.unwrap();
assert_eq!(response.status, 200);
assert_eq!(server.session_count().await, 1);
tokio::time::sleep(Duration::from_millis(120)).await;
assert_eq!(server.session_count().await, 0);
}
#[tokio::test]
async fn test_target_request_accepted() {
let config = MasqueRelayConfig::default();
let server = MasqueRelayServer::new(config, test_addr(9000));
let request = ConnectUdpRequest::target(test_addr(8080));
let response = server
.handle_connect_request(&request, client_addr(1))
.await
.unwrap();
assert_eq!(response.status, 200);
}
#[tokio::test]
async fn test_close_session() {
let config = MasqueRelayConfig::default();
let server = MasqueRelayServer::new(config, test_addr(9000));
let client = client_addr(1);
let request = ConnectUdpRequest::bind_any();
let response = server
.handle_connect_request(&request, client)
.await
.unwrap();
assert_eq!(response.status, 200);
assert_eq!(server.session_count().await, 1);
let session_ids = server.active_session_ids().await;
assert_eq!(session_ids.len(), 1);
server.close_session(session_ids[0]).await.unwrap();
assert_eq!(server.session_count().await, 0);
assert!(
!server.client_to_session.read().await.contains_key(&client),
"client-to-session map should be cleared when closing a session"
);
}
#[tokio::test]
async fn test_close_session_by_client() {
let config = MasqueRelayConfig::default();
let server = MasqueRelayServer::new(config, test_addr(9000));
let client = client_addr(1);
let request = ConnectUdpRequest::bind_any();
server
.handle_connect_request(&request, client)
.await
.unwrap();
assert_eq!(server.session_count().await, 1);
server.close_session_by_client(client).await.unwrap();
assert_eq!(server.session_count().await, 0);
}
#[tokio::test]
async fn test_server_stats() {
let config = MasqueRelayConfig::default();
let server = MasqueRelayServer::new(config, test_addr(9000));
let stats = server.stats();
assert_eq!(stats.current_active_sessions(), 0);
let request = ConnectUdpRequest::bind_any();
server
.handle_connect_request(&request, client_addr(1))
.await
.unwrap();
assert_eq!(stats.current_active_sessions(), 1);
assert_eq!(stats.sessions_created.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn test_handle_client_datagram_records_bytes_and_datagram_count() {
use crate::VarInt;
use crate::masque::{Datagram, UncompressedDatagram};
use bytes::Bytes;
let config = MasqueRelayConfig::default();
let server = MasqueRelayServer::new(config, test_addr(9000));
let client = client_addr(42);
server
.handle_connect_request(&ConnectUdpRequest::bind_any(), client)
.await
.unwrap();
let target: SocketAddr = "203.0.113.77:4444".parse().unwrap();
let payload = Bytes::from_static(b"PROBE0123456789-relay-forwarding-check");
let payload_len = payload.len() as u64;
let uncompressed = UncompressedDatagram::new(VarInt::from_u32(2), target, payload.clone());
let datagram = Datagram::from(uncompressed);
let result = server
.handle_client_datagram(client, datagram, payload)
.await;
match result {
DatagramResult::Forward(out) => {
assert_eq!(
out.target, target,
"forwarded datagram must address the peer target"
);
}
other => panic!(
"expected DatagramResult::Forward, got {other:?} — relay forwarding is broken"
),
}
let stats = server.stats();
assert_eq!(
stats.total_bytes_relayed(),
payload_len,
"bytes_relayed must advance by the forwarded payload size (#164)"
);
assert_eq!(
stats.datagrams_forwarded.load(Ordering::Relaxed),
1,
"datagrams_forwarded must advance for each forwarded datagram (#164)"
);
}
#[tokio::test]
async fn test_get_session_info() {
let config = MasqueRelayConfig::default();
let server = MasqueRelayServer::new(config, test_addr(9000));
let client = client_addr(1);
let request = ConnectUdpRequest::bind_any();
server
.handle_connect_request(&request, client)
.await
.unwrap();
let session_ids = server.active_session_ids().await;
let info = server.get_session_info(session_ids[0]).await.unwrap();
assert_eq!(info.client_address, Some(client));
assert_eq!(info.state, RelaySessionState::Active);
}
fn ipv4_addr(port: u16) -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50)), port)
}
/// IPv6 socket address `[2001:db8::1]:<port>` (documentation prefix, RFC 3849).
fn ipv6_addr(port: u16) -> SocketAddr {
    let host = Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 0x0001);
    SocketAddr::from((IpAddr::V6(host), port))
}
/// IPv4 client address `10.0.0.<id>:12345`; `id` picks the last octet.
fn ipv4_client(id: u8) -> SocketAddr {
    const CLIENT_PORT: u16 = 12345;
    let host = Ipv4Addr::new(10, 0, 0, id);
    SocketAddr::new(host.into(), CLIENT_PORT)
}
/// IPv6 client address `[2001:db8::<id>]:12345`; `id` becomes the final hextet.
fn ipv6_client(id: u8) -> SocketAddr {
    const CLIENT_PORT: u16 = 12345;
    let last_segment = u16::from(id);
    let host = Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, last_segment);
    SocketAddr::new(IpAddr::V6(host), CLIENT_PORT)
}
/// A relay built with both an IPv4 and an IPv6 address reports dual-stack
/// support and exposes a secondary address.
#[tokio::test]
async fn test_dual_stack_creation() {
    let relay = MasqueRelayServer::new_dual_stack(
        MasqueRelayConfig::default(),
        ipv4_addr(9000),
        ipv6_addr(9000),
    );
    assert!(relay.supports_dual_stack());
    assert!(relay.secondary_address().is_some());
}
/// A relay constructed with only a primary IPv4 address is not dual-stack and
/// has no secondary address.
#[tokio::test]
async fn test_single_stack_no_dual_stack() {
    let single_stack = MasqueRelayServer::new(MasqueRelayConfig::default(), ipv4_addr(9000));
    assert!(!single_stack.supports_dual_stack());
    assert!(single_stack.secondary_address().is_none());
}
/// An IPv4 client reaching an IPv4 target (same IP version) is bridgeable even
/// on a single-stack relay.
#[tokio::test]
async fn test_can_bridge_same_version() {
    let relay = MasqueRelayServer::new(MasqueRelayConfig::default(), ipv4_addr(9000));
    let (source, target) = (ipv4_client(1), ipv4_addr(8080));
    let bridgeable = relay.can_bridge(source, target).await;
    assert!(bridgeable);
}
/// An IPv4-only relay must refuse to bridge an IPv4 client to an IPv6 target.
#[tokio::test]
async fn test_can_bridge_different_version_without_dual_stack() {
    let v4_only = MasqueRelayServer::new(MasqueRelayConfig::default(), ipv4_addr(9000));
    let bridgeable = v4_only.can_bridge(ipv4_client(1), ipv6_addr(8080)).await;
    assert!(!bridgeable);
}
/// A dual-stack relay can bridge in both directions: IPv4 -> IPv6 and
/// IPv6 -> IPv4.
#[tokio::test]
async fn test_can_bridge_different_version_with_dual_stack() {
    let relay = MasqueRelayServer::new_dual_stack(
        MasqueRelayConfig::default(),
        ipv4_addr(9000),
        ipv6_addr(9000),
    );

    let v4_to_v6 = relay.can_bridge(ipv4_client(1), ipv6_addr(8080)).await;
    assert!(v4_to_v6);

    let v6_to_v4 = relay.can_bridge(ipv6_client(1), ipv4_addr(8080)).await;
    assert!(v6_to_v4);
}
/// On a dual-stack relay, an IPv4 target is served from an IPv4 relay address.
#[tokio::test]
async fn test_address_for_target_ipv4() {
    let relay = MasqueRelayServer::new_dual_stack(
        MasqueRelayConfig::default(),
        ipv4_addr(9000),
        ipv6_addr(9000),
    );
    let chosen = relay.address_for_target(&ipv4_addr(8080));
    assert!(chosen.is_ipv4());
}
/// On a dual-stack relay, an IPv6 target is served from an IPv6 relay address.
#[tokio::test]
async fn test_address_for_target_ipv6() {
    let relay = MasqueRelayServer::new_dual_stack(
        MasqueRelayConfig::default(),
        ipv4_addr(9000),
        ipv6_addr(9000),
    );
    let chosen = relay.address_for_target(&ipv6_addr(8080));
    assert!(chosen.is_ipv6());
}
/// An IPv4-only relay answers a CONNECT-UDP request from an IPv4 client for an
/// IPv6 target with status 501.
#[tokio::test]
async fn test_bridging_connect_request_rejected_without_dual_stack() {
    let relay = MasqueRelayServer::new(MasqueRelayConfig::default(), ipv4_addr(9000));
    let cross_family = ConnectUdpRequest::target(ipv6_addr(8080));
    let status = relay
        .handle_connect_request(&cross_family, ipv4_client(1))
        .await
        .unwrap()
        .status;
    assert_eq!(status, 501);
}
/// On a dual-stack relay, a bind-any session opened by an IPv4 client succeeds
/// (200) and advertises an IPv4 public address.
#[tokio::test]
async fn test_ipv4_client_session() {
    let relay = MasqueRelayServer::new_dual_stack(
        MasqueRelayConfig::default(),
        ipv4_addr(9000),
        ipv6_addr(9000),
    );
    let reply = relay
        .handle_connect_request(&ConnectUdpRequest::bind_any(), ipv4_client(1))
        .await
        .unwrap();
    assert_eq!(reply.status, 200);

    let advertised = reply.proxy_public_address.unwrap();
    assert!(advertised.is_ipv4());
}
/// On a dual-stack relay, a bind-any session opened by an IPv6 client succeeds
/// (200) and advertises an IPv6 public address.
#[tokio::test]
async fn test_ipv6_client_session() {
    let relay = MasqueRelayServer::new_dual_stack(
        MasqueRelayConfig::default(),
        ipv4_addr(9000),
        ipv6_addr(9000),
    );
    let reply = relay
        .handle_connect_request(&ConnectUdpRequest::bind_any(), ipv6_client(1))
        .await
        .unwrap();
    assert_eq!(reply.status, 200);

    let advertised = reply.proxy_public_address.unwrap();
    assert!(advertised.is_ipv6());
}
/// The bridged-connection counter starts at zero and stays at zero after an
/// IPv4 client opens a bind-any session.
#[tokio::test]
async fn test_bridged_connection_count() {
    let relay = MasqueRelayServer::new_dual_stack(
        MasqueRelayConfig::default(),
        ipv4_addr(9000),
        ipv6_addr(9000),
    );
    // Fresh relay: nothing bridged yet.
    assert_eq!(relay.bridged_connection_count(), 0);

    relay
        .handle_connect_request(&ConnectUdpRequest::bind_any(), ipv4_client(1))
        .await
        .unwrap();
    // The bind-any session from an IPv4 client is not counted as bridged.
    assert_eq!(relay.bridged_connection_count(), 0);
}
/// A bind-any session opened by an IPv4 client on a dual-stack relay is
/// reported with `is_bridging == false`.
#[tokio::test]
async fn test_session_bridging_flag() {
    let config = MasqueRelayConfig::default();
    let v4 = ipv4_addr(9000);
    let v6 = ipv6_addr(9000);
    let server = MasqueRelayServer::new_dual_stack(config, v4, v6);
    let request = ConnectUdpRequest::bind_any();
    server
        .handle_connect_request(&request, ipv4_client(1))
        .await
        .unwrap();
    let session_ids = server.active_session_ids().await;
    // Check the count explicitly so a missing (or duplicated) session fails with
    // a descriptive message instead of an index-out-of-bounds panic.
    assert_eq!(
        session_ids.len(),
        1,
        "exactly one active session expected after a single CONNECT-UDP request"
    );
    let info = server
        .get_session_info(session_ids[0])
        .await
        .expect("session info must exist for an id returned by active_session_ids");
    assert!(!info.is_bridging);
}
}