use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant, SystemTime};
use tokio::net::UdpSocket;
use crate::VarInt;
use crate::masque::{
Capsule, CompressionAck, CompressionAssign, CompressionClose, ContextError, ContextManager,
Datagram,
};
use crate::relay::error::{RelayError, RelayResult, SessionErrorKind};
/// Current wall-clock time as whole milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch
/// (e.g. a badly skewed clock), so callers never see an error.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
/// Tunable limits applied to a single relay session.
#[derive(Debug, Clone)]
pub struct RelaySessionConfig {
    /// Maximum throughput in bytes per second; 0 disables rate limiting.
    pub bandwidth_limit: u64,
    /// Idle period after which the session counts as timed out.
    pub session_timeout: Duration,
    /// Upper bound on concurrently active compression contexts.
    pub max_contexts: usize,
    /// Datagram buffer size in bytes.
    pub datagram_buffer_size: usize,
}

impl Default for RelaySessionConfig {
    /// Defaults: 1 MiB/s bandwidth, 5-minute timeout, 100 contexts,
    /// 64 KiB datagram buffer.
    fn default() -> Self {
        RelaySessionConfig {
            datagram_buffer_size: 65536,
            max_contexts: 100,
            session_timeout: Duration::from_secs(300),
            bandwidth_limit: 1_048_576,
        }
    }
}
/// Lifecycle states of a [`RelaySession`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RelaySessionState {
/// Freshly created; `activate()` is the only valid transition out.
Pending,
/// Live; capsules and datagrams are processed in this state only.
Active,
/// Transient state observed during `close()` teardown.
Closing,
/// Closed normally via `close()`.
Closed,
/// Aborted via `set_error()`; `close()` is a no-op afterwards.
Error,
}
/// Monotonic per-session counters.
///
/// All fields are atomics updated with `Relaxed` ordering, so the struct
/// can be shared behind an `Arc` and bumped from anywhere without locking.
#[derive(Debug, Default)]
pub struct RelaySessionStats {
    /// Bytes recorded as sent.
    pub bytes_sent: AtomicU64,
    /// Bytes recorded as received.
    pub bytes_received: AtomicU64,
    /// Number of datagrams forwarded.
    pub datagrams_forwarded: AtomicU64,
    /// Number of capsules handled.
    pub capsules_processed: AtomicU64,
    /// Number of compression contexts successfully registered.
    pub contexts_registered: AtomicU64,
}

impl RelaySessionStats {
    /// Creates a zeroed counter set.
    pub fn new() -> Self {
        RelaySessionStats::default()
    }

    /// Adds `bytes` to the sent counter.
    pub fn record_bytes_sent(&self, bytes: u64) {
        self.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
    }

    /// Adds `bytes` to the received counter.
    pub fn record_bytes_received(&self, bytes: u64) {
        self.bytes_received.fetch_add(bytes, Ordering::Relaxed);
    }

    /// Counts one forwarded datagram.
    pub fn record_datagram(&self) {
        self.datagrams_forwarded.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one processed capsule.
    pub fn record_capsule(&self) {
        self.capsules_processed.fetch_add(1, Ordering::Relaxed);
    }

    /// Total bytes recorded as sent so far.
    pub fn total_bytes_sent(&self) -> u64 {
        self.bytes_sent.load(Ordering::Relaxed)
    }

    /// Total bytes recorded as received so far.
    pub fn total_bytes_received(&self) -> u64 {
        self.bytes_received.load(Ordering::Relaxed)
    }
}
/// Server-side state for one relay session: lifecycle, compression-context
/// bookkeeping, rate limiting, and shared statistics.
#[derive(Debug)]
pub struct RelaySession {
// Identifier assigned at construction; echoed (truncated to u32) in errors.
session_id: u64,
// Per-session limits; see `RelaySessionConfig`.
config: RelaySessionConfig,
// Current lifecycle state; see `RelaySessionState`.
state: RelaySessionState,
// Relay-side address advertised for this session.
public_address: SocketAddr,
// Peer address, once recorded via `set_client_address`.
client_address: Option<SocketAddr>,
// Compression-context bookkeeping (ASSIGN/ACK/CLOSE handling).
context_manager: ContextManager,
// Reverse index: target address -> compression context id.
target_to_context: HashMap<SocketAddr, VarInt>,
// Creation instant; basis for `duration()`.
created_at: Instant,
// Refreshed on activation and on each handled capsule; basis for `is_timed_out()`.
last_activity: Instant,
// Shared counters; handed out via `stats()`.
stats: Arc<RelaySessionStats>,
// Bridging flag toggled via `set_bridging`; semantics owned by the caller.
is_bridging: bool,
// Forwarding socket, injected via `set_udp_socket`.
udp_socket: Option<Arc<UdpSocket>>,
// Bytes counted against the current one-second rate-limit window.
bytes_in_window: AtomicU64,
// Start of the current rate-limit window, in ms since the Unix epoch.
window_start_ms: AtomicU64,
}
impl RelaySession {
pub fn new(session_id: u64, config: RelaySessionConfig, public_address: SocketAddr) -> Self {
let now = Instant::now();
Self {
session_id,
config,
state: RelaySessionState::Pending,
public_address,
client_address: None,
context_manager: ContextManager::new(false), target_to_context: HashMap::new(),
created_at: now,
last_activity: now,
stats: Arc::new(RelaySessionStats::new()),
is_bridging: false,
udp_socket: None,
bytes_in_window: AtomicU64::new(0),
window_start_ms: AtomicU64::new(now_ms()),
}
}
/// Identifier this session was created with.
pub fn session_id(&self) -> u64 {
self.session_id
}
/// Current lifecycle state.
pub fn state(&self) -> RelaySessionState {
self.state
}
/// Relay-side address advertised for this session.
pub fn public_address(&self) -> SocketAddr {
self.public_address
}
/// Records the client's address once it is known.
pub fn set_client_address(&mut self, addr: SocketAddr) {
self.client_address = Some(addr);
}
/// Client address, if one has been recorded.
pub fn client_address(&self) -> Option<SocketAddr> {
self.client_address
}
/// Shared handle to this session's counters.
pub fn stats(&self) -> Arc<RelaySessionStats> {
Arc::clone(&self.stats)
}
/// Time elapsed since the session was created.
pub fn duration(&self) -> Duration {
self.created_at.elapsed()
}
/// True when no activity has been recorded within `config.session_timeout`.
pub fn is_timed_out(&self) -> bool {
self.last_activity.elapsed() > self.config.session_timeout
}
/// True only in the `Active` state.
pub fn is_active(&self) -> bool {
self.state == RelaySessionState::Active
}
/// Read-only view of the session configuration.
pub fn config(&self) -> &RelaySessionConfig {
&self.config
}
/// Attaches the UDP socket used for forwarding.
pub fn set_udp_socket(&mut self, socket: Arc<UdpSocket>) {
self.udp_socket = Some(socket);
}
/// Forwarding socket, if one has been attached.
pub fn udp_socket(&self) -> Option<&Arc<UdpSocket>> {
self.udp_socket.as_ref()
}
/// Overrides the advertised public address.
pub fn set_public_address(&mut self, addr: SocketAddr) {
self.public_address = addr;
}
/// Toggles bridging mode; interpretation is up to the caller.
pub fn set_bridging(&mut self, bridging: bool) {
self.is_bridging = bridging;
}
/// Current bridging flag.
pub fn is_bridging(&self) -> bool {
self.is_bridging
}
/// One-second-window rate limiter.
///
/// Returns `true` when sending `bytes` now stays within
/// `config.bandwidth_limit` bytes per second, and records the bytes
/// against the current window; rejected sends consume no budget.
/// A limit of 0 disables rate limiting entirely.
///
/// NOTE(review): the window reset (separate load + store) is not atomic
/// as a unit, so concurrent callers racing across a window boundary may
/// briefly over- or under-count — acceptable for a soft limit.
pub fn check_rate_limit(&self, bytes: usize) -> bool {
    if self.config.bandwidth_limit == 0 {
        return true;
    }
    let bytes = bytes as u64;
    let now = now_ms();
    let window_start = self.window_start_ms.load(Ordering::Relaxed);
    if now.saturating_sub(window_start) >= 1000 {
        // The previous window has expired: start a fresh one.
        self.window_start_ms.store(now, Ordering::Relaxed);
        if bytes <= self.config.bandwidth_limit {
            self.bytes_in_window.store(bytes, Ordering::Relaxed);
            true
        } else {
            // Bug fix: previously the rejected bytes were still stored as
            // the new window's usage, so an oversized packet starved
            // subsequent in-limit sends for the rest of the window. A
            // rejected send must not consume budget, matching the
            // rollback in the in-window path below.
            self.bytes_in_window.store(0, Ordering::Relaxed);
            false
        }
    } else {
        let prior = self.bytes_in_window.fetch_add(bytes, Ordering::Relaxed);
        if prior.saturating_add(bytes) > self.config.bandwidth_limit {
            // Roll back: rejected sends do not count against the window.
            self.bytes_in_window.fetch_sub(bytes, Ordering::Relaxed);
            false
        } else {
            true
        }
    }
}
/// Transitions the session from `Pending` to `Active` and stamps
/// `last_activity`.
///
/// # Errors
/// Returns `SessionError`/`InvalidState` when called in any state other
/// than `Pending` — activation is one-shot.
pub fn activate(&mut self) -> RelayResult<()> {
match self.state {
RelaySessionState::Pending => {
self.state = RelaySessionState::Active;
self.last_activity = Instant::now();
Ok(())
}
// NOTE(review): the u64 session id is truncated to u32 for the error
// payload; ids above u32::MAX would alias.
_ => Err(RelayError::SessionError {
session_id: Some(self.session_id as u32),
kind: SessionErrorKind::InvalidState {
current_state: format!("{:?}", self.state),
expected_state: "Pending".into(),
},
}),
}
}
/// Dispatches one inbound capsule and returns an optional capsule to send
/// back to the peer.
///
/// Refreshes `last_activity` and the capsule counter before dispatching.
///
/// # Errors
/// Returns `SessionError`/`InvalidState` unless the session is `Active`.
pub fn handle_capsule(&mut self, capsule: Capsule) -> RelayResult<Option<Capsule>> {
if !self.is_active() {
return Err(RelayError::SessionError {
session_id: Some(self.session_id as u32),
kind: SessionErrorKind::InvalidState {
current_state: format!("{:?}", self.state),
expected_state: "Active".into(),
},
});
}
self.last_activity = Instant::now();
self.stats.record_capsule();
match capsule {
Capsule::CompressionAssign(assign) => self.handle_compression_assign(assign),
Capsule::CompressionAck(ack) => self.handle_compression_ack(ack),
Capsule::CompressionClose(close) => self.handle_compression_close(close),
// Unknown capsule types are logged and ignored rather than failing
// the session.
Capsule::Unknown { capsule_type, .. } => {
tracing::debug!(
session_id = self.session_id,
capsule_type = capsule_type.into_inner(),
"Ignoring unknown capsule type"
);
Ok(None)
}
}
}
/// Handles a peer-initiated compression ASSIGN.
///
/// Replies with an ACK on success, or a CLOSE echoing the same context id
/// when the assignment is rejected: context limit reached, target already
/// mapped, or registration failure. Rejection is signalled in-band via the
/// CLOSE; it is never surfaced as a session error.
fn handle_compression_assign(
&mut self,
assign: CompressionAssign,
) -> RelayResult<Option<Capsule>> {
// Enforce the configured cap on concurrently active contexts.
if self.context_manager.active_count() >= self.config.max_contexts {
return Ok(Some(Capsule::CompressionClose(CompressionClose::new(
assign.context_id,
))));
}
let target = assign.target();
// Reject a second context for a target that is already mapped.
if let Some(t) = target {
if self.target_to_context.contains_key(&t) {
return Ok(Some(Capsule::CompressionClose(CompressionClose::new(
assign.context_id,
))));
}
}
let result = self
.context_manager
.register_remote(assign.context_id, target)
.map(|_| {
// Keep the reverse target -> context index in sync for outbound
// lookups (see `context_for_target`).
if let Some(t) = target {
self.target_to_context.insert(t, assign.context_id);
}
});
match result {
Ok(_) => {
self.stats
.contexts_registered
.fetch_add(1, Ordering::Relaxed);
Ok(Some(Capsule::CompressionAck(CompressionAck::new(
assign.context_id,
))))
}
Err(e) => {
tracing::warn!(
session_id = self.session_id,
context_id = assign.context_id.into_inner(),
error = %e,
"Failed to register context"
);
Ok(Some(Capsule::CompressionClose(CompressionClose::new(
assign.context_id,
))))
}
}
}
/// Processes a peer ACK for a previously sent compression ASSIGN.
///
/// An ACK referencing an unknown context is logged and ignored rather than
/// treated as a session error; no response capsule is ever produced.
fn handle_compression_ack(&mut self, ack: CompressionAck) -> RelayResult<Option<Capsule>> {
    if let Err(e) = self.context_manager.handle_ack(ack.context_id) {
        tracing::warn!(
            session_id = self.session_id,
            context_id = ack.context_id.into_inner(),
            error = %e,
            "Unexpected ACK for unknown context"
        );
    }
    Ok(None)
}
/// Handles a peer compression CLOSE: drops the reverse target mapping and
/// closes the context. Closing an unknown context is treated as
/// already-closed (idempotent); other close errors are logged and
/// swallowed. Never produces a response capsule.
fn handle_compression_close(
&mut self,
close: CompressionClose,
) -> RelayResult<Option<Capsule>> {
// Remove the reverse-index entry first so the target can be reassigned.
if let Some(target) = self.context_manager.get_target(close.context_id) {
self.target_to_context.remove(&target);
}
match self.context_manager.close(close.context_id) {
Ok(_) | Err(ContextError::UnknownContext) => Ok(None),
Err(e) => {
tracing::warn!(
session_id = self.session_id,
context_id = close.context_id.into_inner(),
error = %e,
"Error closing context"
);
Ok(None)
}
}
}
/// Destination address for an inbound datagram: compressed datagrams are
/// resolved through the context manager, uncompressed ones carry the
/// target inline. `None` means the context id is unknown.
pub fn resolve_target(&self, datagram: &Datagram) -> Option<SocketAddr> {
match datagram {
Datagram::Compressed(d) => self.context_manager.get_target(d.context_id),
Datagram::Uncompressed(d) => Some(d.target),
}
}
/// Returns the compression context id mapped to `target`, allocating and
/// registering a new local context when none exists yet.
///
/// # Errors
/// `ResourceExhausted` when no more contexts can be allocated;
/// `SessionError` when registering the freshly allocated id fails.
pub fn context_for_target(&mut self, target: SocketAddr) -> RelayResult<VarInt> {
// Fast path: reuse the existing mapping.
if let Some(&ctx_id) = self.target_to_context.get(&target) {
return Ok(ctx_id);
}
let ctx_id =
self.context_manager
.allocate_local()
.map_err(|_| RelayError::ResourceExhausted {
resource_type: "contexts".into(),
current_usage: self.context_manager.active_count() as u64,
limit: self.config.max_contexts as u64,
})?;
// NOTE(review): if registration fails, the allocated id is not returned
// to the manager here — confirm allocate_local ids are safe to leak.
self.context_manager
.register_compressed(ctx_id, target)
.map_err(|_| RelayError::SessionError {
session_id: Some(self.session_id as u32),
kind: SessionErrorKind::InvalidState {
current_state: "context registration failed".into(),
expected_state: "registered".into(),
},
})?;
self.target_to_context.insert(target, ctx_id);
Ok(ctx_id)
}
/// Builds the compression ASSIGN capsule advertising `ctx_id` for `target`,
/// choosing the v4 or v6 encoding to match the address family.
pub fn create_assign_capsule(&self, ctx_id: VarInt, target: SocketAddr) -> Capsule {
let assign = match target {
SocketAddr::V4(v4) => CompressionAssign::compressed_v4(ctx_id, *v4.ip(), v4.port()),
SocketAddr::V6(v6) => CompressionAssign::compressed_v6(ctx_id, *v6.ip(), v6.port()),
};
Capsule::CompressionAssign(assign)
}
/// Records `bytes` as sent (plus one forwarded datagram) unless the
/// session's lifetime-average throughput already exceeds the bandwidth
/// limit.
///
/// Unlike `check_rate_limit` (one-second window), this checks the average
/// rate over the whole session duration, and the check does not include
/// the current `bytes`. On rejection nothing is recorded, so a later call
/// can succeed once the average decays below the limit.
///
/// # Errors
/// `SessionError`/`BandwidthExceeded` when the average rate is over limit.
pub fn record_bandwidth(&self, bytes: u64) -> RelayResult<()> {
    // Bug fix: a limit of 0 now means "unlimited", matching
    // `check_rate_limit`. Previously a zero limit made every call fail as
    // soon as any bytes had been recorded (rate > 0.0 always held).
    if self.config.bandwidth_limit != 0 {
        let total = self.stats.total_bytes_sent() + self.stats.total_bytes_received();
        let duration = self.duration().as_secs_f64();
        if duration > 0.0 {
            let rate = total as f64 / duration;
            if rate > self.config.bandwidth_limit as f64 {
                return Err(RelayError::SessionError {
                    session_id: Some(self.session_id as u32),
                    kind: SessionErrorKind::BandwidthExceeded {
                        used: rate as u64,
                        limit: self.config.bandwidth_limit,
                    },
                });
            }
        }
    }
    self.stats.record_bytes_sent(bytes);
    self.stats.record_datagram();
    Ok(())
}
pub fn close(&mut self) {
match self.state {
RelaySessionState::Closed | RelaySessionState::Error => {}
_ => {
self.state = RelaySessionState::Closing;
self.target_to_context.clear();
self.state = RelaySessionState::Closed;
}
}
}
pub fn set_error(&mut self) {
self.state = RelaySessionState::Error;
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr};
// Helper: fixed test IP (192.168.1.100) with a caller-chosen port.
fn test_addr(port: u16) -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), port)
}
// New sessions start Pending and inactive, echoing id and address.
#[test]
fn test_session_creation() {
let config = RelaySessionConfig::default();
let public_addr = test_addr(9000);
let session = RelaySession::new(1, config, public_addr);
assert_eq!(session.session_id(), 1);
assert_eq!(session.state(), RelaySessionState::Pending);
assert_eq!(session.public_address(), public_addr);
assert!(!session.is_active());
}
// activate() moves Pending -> Active.
#[test]
fn test_session_activation() {
let config = RelaySessionConfig::default();
let session_id = 1;
let mut session = RelaySession::new(session_id, config, test_addr(9000));
assert!(session.activate().is_ok());
assert!(session.is_active());
assert_eq!(session.state(), RelaySessionState::Active);
}
// Activation is one-shot: a second activate() must fail.
#[test]
fn test_session_activation_from_wrong_state() {
let config = RelaySessionConfig::default();
let mut session = RelaySession::new(1, config, test_addr(9000));
session.activate().unwrap();
assert!(session.activate().is_err());
}
// A valid ASSIGN is answered with an ACK echoing the context id.
#[test]
fn test_handle_compression_assign() {
let config = RelaySessionConfig::default();
let mut session = RelaySession::new(1, config, test_addr(9000));
session.activate().unwrap();
// NOTE(review): even context ids are used throughout these tests —
// presumably the parity reserved for remote/peer-allocated contexts;
// confirm against ContextManager's allocation rules.
let assign = CompressionAssign::compressed_v4(
VarInt::from_u32(2), Ipv4Addr::new(192, 168, 1, 100),
8080,
);
let capsule = Capsule::CompressionAssign(assign);
let response = session.handle_capsule(capsule).unwrap();
match response {
Some(Capsule::CompressionAck(ack)) => {
assert_eq!(ack.context_id, VarInt::from_u32(2));
}
_ => panic!("Expected CompressionAck"),
}
}
// Once max_contexts is reached, a further ASSIGN is answered with CLOSE.
#[test]
fn test_context_limit() {
let config = RelaySessionConfig {
max_contexts: 2,
..Default::default()
};
let mut session = RelaySession::new(1, config, test_addr(9000));
session.activate().unwrap();
// Fill the table up to the limit with distinct targets.
for i in 0..2 {
let assign = CompressionAssign::compressed_v4(
VarInt::from_u32((i + 1) * 2), Ipv4Addr::new(192, 168, 1, i as u8),
8080 + i as u16,
);
let capsule = Capsule::CompressionAssign(assign);
let response = session.handle_capsule(capsule).unwrap();
assert!(matches!(response, Some(Capsule::CompressionAck(_))));
}
let assign = CompressionAssign::compressed_v4(
VarInt::from_u32(6),
Ipv4Addr::new(192, 168, 1, 3),
8083,
);
let capsule = Capsule::CompressionAssign(assign);
let response = session.handle_capsule(capsule).unwrap();
assert!(matches!(response, Some(Capsule::CompressionClose(_))));
}
// close() lands in Closed and deactivates the session.
#[test]
fn test_session_close() {
let config = RelaySessionConfig::default();
let mut session = RelaySession::new(1, config, test_addr(9000));
session.activate().unwrap();
session.close();
assert_eq!(session.state(), RelaySessionState::Closed);
assert!(!session.is_active());
}
// Sent/received/datagram counters accumulate independently.
#[test]
fn test_session_stats() {
let config = RelaySessionConfig::default();
let session = RelaySession::new(1, config, test_addr(9000));
session.stats.record_bytes_sent(100);
session.stats.record_bytes_received(50);
session.stats.record_datagram();
assert_eq!(session.stats.total_bytes_sent(), 100);
assert_eq!(session.stats.total_bytes_received(), 50);
assert_eq!(session.stats.datagrams_forwarded.load(Ordering::Relaxed), 1);
}
// A second ASSIGN for an already-mapped target is rejected with CLOSE.
#[test]
fn test_duplicate_target_rejected() {
let config = RelaySessionConfig::default();
let mut session = RelaySession::new(1, config, test_addr(9000));
session.activate().unwrap();
let target = Ipv4Addr::new(192, 168, 1, 100);
let port = 8080u16;
let assign1 = CompressionAssign::compressed_v4(VarInt::from_u32(2), target, port);
let response1 = session
.handle_capsule(Capsule::CompressionAssign(assign1))
.unwrap();
assert!(matches!(response1, Some(Capsule::CompressionAck(_))));
let assign2 = CompressionAssign::compressed_v4(VarInt::from_u32(4), target, port);
let response2 = session
.handle_capsule(Capsule::CompressionAssign(assign2))
.unwrap();
assert!(matches!(response2, Some(Capsule::CompressionClose(_))));
}
// Sends that stay within the per-second budget are all allowed.
#[test]
fn test_rate_limit_allows_within_limit() {
let config = RelaySessionConfig {
bandwidth_limit: 1000,
..Default::default()
};
let session = RelaySession::new(1, config, test_addr(9000));
assert!(session.check_rate_limit(500));
assert!(session.check_rate_limit(400));
assert!(session.check_rate_limit(100));
}
// A send that would push the window over its budget is rejected.
#[test]
fn test_rate_limit_rejects_over_limit() {
let config = RelaySessionConfig {
bandwidth_limit: 1000,
..Default::default()
};
let session = RelaySession::new(1, config, test_addr(9000));
assert!(session.check_rate_limit(900));
assert!(!session.check_rate_limit(200));
}
// bandwidth_limit == 0 disables rate limiting entirely.
#[test]
fn test_rate_limit_zero_means_unlimited() {
let config = RelaySessionConfig {
bandwidth_limit: 0,
..Default::default()
};
let session = RelaySession::new(1, config, test_addr(9000));
assert!(session.check_rate_limit(999_999_999));
}
}