use crate::error::OverlayError;
use crate::nat::config::RelayServerConfig;
use crate::nat::turn::{
build_control_msg, build_data_msg_tagged, decode_addr, derive_auth_key, encode_addr,
parse_and_verify_control, parse_data_payload_tagged, parse_msg, MsgType,
PEER_ADDR_V4_TAGGED_LEN,
};
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use rand::Rng;
use tokio::net::UdpSocket;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
/// Lifetime, in seconds, given to a new allocation and advertised in `AllocateResp`.
const DEFAULT_LIFETIME_SECS: u32 = 10 * 60;
/// Size of the receive buffers used by the control socket and by every relay socket.
const MAX_PACKET_SIZE: usize = 64 * 1024;
/// Server-side state for one client's relay allocation.
// Some fields (e.g. `relay_addr`) are kept for bookkeeping but never read back.
#[allow(dead_code)]
// Field names intentionally repeat the struct's domain ("allocation_id", "relay_*").
#[allow(clippy::struct_field_names)]
struct Allocation {
    /// Source address of the `AllocateReq`; the only address allowed to manage
    /// this allocation and the destination for forwarded peer data.
    client_addr: SocketAddr,
    /// Random 16-byte identifier; key of the allocation table.
    allocation_id: [u8; 16],
    /// Ephemeral-port UDP socket used to exchange datagrams with peers.
    relay_socket: Arc<UdpSocket>,
    /// External IP combined with `relay_socket`'s port, as reported to the client.
    relay_addr: SocketAddr,
    /// Peer IPs whose traffic may pass through this allocation in either direction.
    permissions: Vec<IpAddr>,
    /// Lifetime in seconds, measured from `refreshed_at`.
    lifetime_secs: u32,
    /// When the allocation was created or last refreshed.
    refreshed_at: std::time::Instant,
    /// Task forwarding peer datagrams to the client; aborted on teardown.
    relay_handle: tokio::task::JoinHandle<()>,
}
/// Shared table of live allocations, keyed by allocation id.
type AllocationTable = Arc<RwLock<HashMap<[u8; 16], Allocation>>>;
/// Shared reverse index from a client's source address to its allocation id.
type ClientLookup = Arc<RwLock<HashMap<SocketAddr, [u8; 16]>>>;
/// UDP relay server that allocates per-client relay sockets and forwards
/// traffic between clients and their permitted peers.
pub struct RelayServer {
    /// Listen port, external address and session cap.
    config: RelayServerConfig,
    /// Key used to build and verify control messages, derived from the credential.
    auth_key: [u8; 32],
    /// Set by [`RelayServer::shutdown`]; polled by the background service loop.
    shutdown: Arc<AtomicBool>,
}
impl RelayServer {
#[must_use]
pub fn new(config: &RelayServerConfig, auth_credential: &str) -> Self {
Self {
config: config.clone(),
auth_key: derive_auth_key(auth_credential),
shutdown: Arc::new(AtomicBool::new(false)),
}
}
pub async fn start(&self) -> Result<(), OverlayError> {
let external_addr: SocketAddr = self
.config
.external_addr
.parse()
.map_err(|e| OverlayError::TurnRelay(format!("Invalid external addr: {e}")))?;
let listen_addr = if external_addr.is_ipv6() {
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), self.config.listen_port)
} else {
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), self.config.listen_port)
};
let socket = Arc::new(UdpSocket::bind(listen_addr).await.map_err(|e| {
OverlayError::TurnRelay(format!("Failed to bind relay server on {listen_addr}: {e}"))
})?);
info!(
listen = %listen_addr,
external = %self.config.external_addr,
max_sessions = self.config.max_sessions,
"Relay server started"
);
let allocations: AllocationTable = Arc::new(RwLock::new(HashMap::new()));
let client_lookup: ClientLookup = Arc::new(RwLock::new(HashMap::new()));
let auth_key = self.auth_key;
let max_sessions = self.config.max_sessions;
let shutdown = self.shutdown.clone();
let socket_clone = socket.clone();
tokio::spawn(async move {
#[allow(clippy::large_stack_arrays)]
let mut buf = [0u8; MAX_PACKET_SIZE];
loop {
if shutdown.load(Ordering::Relaxed) {
info!("Relay server shutting down");
break;
}
let recv_result = tokio::time::timeout(
std::time::Duration::from_secs(1),
socket_clone.recv_from(&mut buf),
)
.await;
let (n, from) = match recv_result {
Ok(Ok((n, from))) => (n, from),
Ok(Err(e)) => {
warn!(error = %e, "Relay server recv error");
continue;
}
Err(_) => continue, };
let packet = &buf[..n];
let Some(msg_type_byte) = packet.first() else {
continue;
};
let Some(msg_type) = MsgType::from_byte(*msg_type_byte) else {
continue;
};
match msg_type {
MsgType::AllocateReq => {
handle_allocate_req(
packet,
from,
&auth_key,
external_addr,
max_sessions,
&allocations,
&client_lookup,
&socket_clone,
)
.await;
}
MsgType::PermissionReq => {
handle_permission_req(packet, from, &auth_key, &allocations, &socket_clone)
.await;
}
MsgType::RefreshReq => {
handle_refresh_req(packet, from, &auth_key, &allocations, &socket_clone)
.await;
}
MsgType::Deallocate => {
handle_deallocate(packet, from, &auth_key, &allocations, &client_lookup)
.await;
}
MsgType::Data => {
handle_data(packet, from, &allocations, &client_lookup).await;
}
_ => {
debug!(msg_type = ?msg_type, from = %from, "Ignoring unexpected message type");
}
}
}
let allocs = allocations.write().await;
for (_, alloc) in allocs.iter() {
alloc.relay_handle.abort();
}
});
Ok(())
}
pub fn shutdown(&self) {
self.shutdown.store(true, Ordering::Relaxed);
info!("Relay server shutdown signaled");
}
}
/// Handles an `AllocateReq` from `from`.
///
/// Unauthenticated requests are dropped without a reply. The request is
/// rejected with `AllocateErr` when the session cap is reached, `from` already
/// owns an allocation, or a relay socket cannot be set up. Otherwise:
/// - binds a relay socket on an ephemeral port (same family as `external_addr`);
/// - registers the allocation under a random 16-byte id;
/// - starts the peer→client forwarder;
/// - replies with `AllocateResp` carrying the encoded relay address, the
///   allocation id and the big-endian lifetime.
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
async fn handle_allocate_req(
    packet: &[u8],
    from: SocketAddr,
    auth_key: &[u8; 32],
    external_addr: SocketAddr,
    max_sessions: usize,
    allocations: &AllocationTable,
    client_lookup: &ClientLookup,
    main_socket: &Arc<UdpSocket>,
) {
    let Some((_msg_type, payload)) = parse_and_verify_control(packet, auth_key) else {
        debug!(from = %from, "AllocateReq auth failed");
        return;
    };
    {
        let allocs = allocations.read().await;
        if allocs.len() >= max_sessions {
            let err_msg =
                build_control_msg(MsgType::AllocateErr, b"max sessions reached", auth_key);
            let _ = main_socket.send_to(&err_msg, from).await;
            return;
        }
    }
    {
        let lookup = client_lookup.read().await;
        if lookup.contains_key(&from) {
            let err_msg =
                build_control_msg(MsgType::AllocateErr, b"allocation already exists", auth_key);
            let _ = main_socket.send_to(&err_msg, from).await;
            return;
        }
    }
    // Optional length-prefixed username (used only for logging): one length byte,
    // then that many bytes decoded lossily as UTF-8; "unknown" if absent/truncated.
    let username = payload
        .split_first()
        .and_then(|(&len, rest)| rest.get(..usize::from(len)))
        .map_or_else(
            || "unknown".to_string(),
            |name| String::from_utf8_lossy(name).into_owned(),
        );
    let relay_bind_addr = if external_addr.is_ipv6() {
        "[::]:0"
    } else {
        "0.0.0.0:0"
    };
    let relay_socket = match UdpSocket::bind(relay_bind_addr).await {
        Ok(s) => Arc::new(s),
        Err(e) => {
            warn!(error = %e, "Failed to bind relay socket for allocation");
            let err_msg = build_control_msg(MsgType::AllocateErr, b"relay bind failed", auth_key);
            let _ = main_socket.send_to(&err_msg, from).await;
            return;
        }
    };
    let relay_port = match relay_socket.local_addr() {
        Ok(addr) => addr.port(),
        Err(e) => {
            warn!(error = %e, "Failed to get relay socket addr");
            // Reply so the client fails fast instead of waiting for a timeout.
            let err_msg = build_control_msg(MsgType::AllocateErr, b"relay bind failed", auth_key);
            let _ = main_socket.send_to(&err_msg, from).await;
            return;
        }
    };
    let relay_addr = SocketAddr::new(external_addr.ip(), relay_port);
    let mut allocation_id = [0u8; 16];
    rand::rng().fill(&mut allocation_id[..]);
    {
        // Take the write lock before spawning the forwarder and keep it until the
        // allocation is inserted. The forwarder looks its allocation up on every
        // datagram and exits if it is missing, so it must never see the table
        // without it.
        let mut allocs = allocations.write().await;
        let relay_handle = tokio::spawn(forward_peer_to_client(
            Arc::clone(&relay_socket),
            Arc::clone(main_socket),
            Arc::clone(allocations),
            allocation_id,
            from,
        ));
        allocs.insert(
            allocation_id,
            Allocation {
                client_addr: from,
                allocation_id,
                relay_socket,
                relay_addr,
                permissions: Vec::new(),
                lifetime_secs: DEFAULT_LIFETIME_SECS,
                refreshed_at: std::time::Instant::now(),
                relay_handle,
            },
        );
    }
    {
        let mut lookup = client_lookup.write().await;
        lookup.insert(from, allocation_id);
    }
    let encoded_relay = encode_addr(relay_addr);
    let mut resp_payload = Vec::with_capacity(encoded_relay.len() + 16 + 4);
    resp_payload.extend_from_slice(&encoded_relay);
    resp_payload.extend_from_slice(&allocation_id);
    resp_payload.extend_from_slice(&DEFAULT_LIFETIME_SECS.to_be_bytes());
    let resp = build_control_msg(MsgType::AllocateResp, &resp_payload, auth_key);
    let _ = main_socket.send_to(&resp, from).await;
    info!(
        client = %from,
        relay = %relay_addr,
        username = %username,
        "Relay allocation created"
    );
}

/// Relays datagrams received on an allocation's relay socket to its client.
///
/// A datagram from a peer whose IP is in the allocation's permission list is
/// wrapped with `build_data_msg_tagged` and sent to `client_addr` over the main
/// socket. Datagrams from other sources are dropped. The loop ends when the
/// allocation is missing from `allocations` (checked per datagram) or the relay
/// socket returns a receive error; otherwise it runs until the task is aborted.
async fn forward_peer_to_client(
    relay_socket: Arc<UdpSocket>,
    main_socket: Arc<UdpSocket>,
    allocations: AllocationTable,
    allocation_id: [u8; 16],
    client_addr: SocketAddr,
) {
    #[allow(clippy::large_stack_arrays)]
    let mut buf = [0u8; MAX_PACKET_SIZE];
    loop {
        let (n, peer_from) = match relay_socket.recv_from(&mut buf).await {
            Ok(received) => received,
            Err(e) => {
                warn!(error = %e, "Relay socket recv error");
                break;
            }
        };
        let permitted = {
            let allocs = allocations.read().await;
            let Some(alloc) = allocs.get(&allocation_id) else {
                break;
            };
            alloc.permissions.contains(&peer_from.ip())
        };
        if !permitted {
            continue;
        }
        let data_msg = build_data_msg_tagged(peer_from, &buf[..n]);
        if let Err(e) = main_socket.send_to(&data_msg, client_addr).await {
            warn!(error = %e, "Failed to forward peer data to client");
        }
    }
}
/// Handles a `PermissionReq`: allows a peer IP to use the sender's allocation.
///
/// The payload is a 16-byte allocation id followed by an address understood by
/// `decode_addr`; only the peer's IP (not its port) is recorded, once. Only the
/// allocation's owner (`client_addr == from`) may add permissions. On success
/// an empty `PermissionResp` is sent. Unauthenticated, malformed, foreign or
/// unknown-allocation requests get no reply.
async fn handle_permission_req(
    packet: &[u8],
    from: SocketAddr,
    auth_key: &[u8; 32],
    allocations: &AllocationTable,
    main_socket: &Arc<UdpSocket>,
) {
    let Some((_msg_type, payload)) = parse_and_verify_control(packet, auth_key) else {
        debug!(from = %from, "PermissionReq auth failed");
        return;
    };
    if payload.len() < 16 + PEER_ADDR_V4_TAGGED_LEN {
        return;
    }
    let (id_bytes, addr_bytes) = payload.split_at(16);
    let mut alloc_id = [0u8; 16];
    alloc_id.copy_from_slice(id_bytes);
    let Some((peer_addr, _)) = decode_addr(addr_bytes) else {
        return;
    };
    let peer_ip = peer_addr.ip();
    {
        let mut table = allocations.write().await;
        let Some(entry) = table.get_mut(&alloc_id) else {
            debug!(from = %from, "PermissionReq for unknown allocation");
            return;
        };
        if entry.client_addr != from {
            debug!(from = %from, "PermissionReq from non-owner");
            return;
        }
        if !entry.permissions.contains(&peer_ip) {
            entry.permissions.push(peer_ip);
        }
    }
    let ack = build_control_msg(MsgType::PermissionResp, &[], auth_key);
    let _ = main_socket.send_to(&ack, from).await;
    debug!(from = %from, peer = %peer_addr, "Permission added");
}
/// Handles a `RefreshReq`: resets an allocation's lifetime.
///
/// The payload is a 16-byte allocation id followed by a big-endian `u32`
/// lifetime in seconds; extra bytes are ignored. Only the owner
/// (`client_addr == from`) may refresh. On success the lifetime and refresh
/// instant are updated and a `RefreshResp` echoing the lifetime is sent.
/// Unauthenticated, short, foreign or unknown-allocation requests get no reply.
async fn handle_refresh_req(
    packet: &[u8],
    from: SocketAddr,
    auth_key: &[u8; 32],
    allocations: &AllocationTable,
    main_socket: &Arc<UdpSocket>,
) {
    let Some((_msg_type, payload)) = parse_and_verify_control(packet, auth_key) else {
        debug!(from = %from, "RefreshReq auth failed");
        return;
    };
    let Some(fields) = payload.get(..16 + 4) else {
        return;
    };
    let (id_bytes, lifetime_bytes) = fields.split_at(16);
    let mut alloc_id = [0u8; 16];
    alloc_id.copy_from_slice(id_bytes);
    let mut raw_lifetime = [0u8; 4];
    raw_lifetime.copy_from_slice(lifetime_bytes);
    let lifetime = u32::from_be_bytes(raw_lifetime);
    {
        let mut table = allocations.write().await;
        let Some(entry) = table.get_mut(&alloc_id) else {
            debug!(from = %from, "RefreshReq for unknown allocation");
            return;
        };
        if entry.client_addr != from {
            debug!(from = %from, "RefreshReq from non-owner");
            return;
        }
        entry.lifetime_secs = lifetime;
        entry.refreshed_at = std::time::Instant::now();
    }
    let ack = build_control_msg(MsgType::RefreshResp, &lifetime.to_be_bytes(), auth_key);
    let _ = main_socket.send_to(&ack, from).await;
    debug!(from = %from, lifetime = lifetime, "Allocation refreshed");
}
/// Handles a `Deallocate`: tears down the sender's allocation.
///
/// The payload starts with the 16-byte allocation id. A request from anyone
/// other than the owner is ignored. When the allocation exists and belongs to
/// `from`, it is removed, its forwarder task is aborted and the client's lookup
/// entry is dropped. No reply is sent in any case.
async fn handle_deallocate(
    packet: &[u8],
    from: SocketAddr,
    auth_key: &[u8; 32],
    allocations: &AllocationTable,
    client_lookup: &ClientLookup,
) {
    let Some((_msg_type, payload)) = parse_and_verify_control(packet, auth_key) else {
        debug!(from = %from, "Deallocate auth failed");
        return;
    };
    let Some(id_bytes) = payload.get(..16) else {
        return;
    };
    let mut alloc_id = [0u8; 16];
    alloc_id.copy_from_slice(id_bytes);
    let removed = {
        let mut table = allocations.write().await;
        match table.get(&alloc_id).map(|entry| entry.client_addr) {
            Some(owner) if owner != from => {
                debug!(from = %from, "Deallocate from non-owner");
                return;
            }
            Some(_) => table.remove(&alloc_id),
            None => None,
        }
    };
    let Some(entry) = removed else {
        return;
    };
    entry.relay_handle.abort();
    client_lookup.write().await.remove(&from);
    info!(client = %from, "Allocation deallocated");
}
async fn handle_data(
packet: &[u8],
from: SocketAddr,
allocations: &AllocationTable,
client_lookup: &ClientLookup,
) {
let alloc_id = {
let lookup = client_lookup.read().await;
match lookup.get(&from) {
Some(id) => *id,
None => return, }
};
let Some((MsgType::Data, payload)) = parse_msg(packet) else {
return;
};
let Some((peer_addr, raw_data)) = parse_data_payload_tagged(payload) else {
return;
};
let allocs = allocations.read().await;
if let Some(alloc) = allocs.get(&alloc_id) {
if !alloc.permissions.iter().any(|p| *p == peer_addr.ip()) {
return;
}
if let Err(e) = alloc.relay_socket.send_to(raw_data, peer_addr).await {
warn!(error = %e, dest = %peer_addr, "Failed to relay data to peer");
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nat::config::TurnServerConfig;

    /// Builds a `RelayServerConfig` from its three fields.
    fn relay_config(listen_port: u16, external_addr: &str, max_sessions: usize) -> RelayServerConfig {
        RelayServerConfig {
            listen_port,
            external_addr: external_addr.to_string(),
            max_sessions,
        }
    }

    /// Builds a client config for `address` with a fixed test username.
    fn client_config(address: &str, credential: &str) -> TurnServerConfig {
        TurnServerConfig {
            address: address.to_string(),
            username: "testuser".to_string(),
            credential: credential.to_string(),
            region: None,
        }
    }

    /// Returns a UDP port that was free on `bind_addr` a moment ago.
    ///
    /// The probe socket is dropped before returning so the relay can bind the
    /// port. Another process could grab it in between, which is acceptable here.
    async fn free_udp_port(bind_addr: &str) -> u16 {
        let probe = UdpSocket::bind(bind_addr).await.unwrap();
        probe.local_addr().unwrap().port()
    }

    #[test]
    fn test_relay_server_new() {
        let server = RelayServer::new(&relay_config(3478, "1.2.3.4:3478", 50), "test_credential");
        assert_eq!(server.config.listen_port, 3478);
        assert_eq!(server.config.max_sessions, 50);
    }

    #[test]
    fn test_relay_server_auth_key_derivation() {
        let server = RelayServer::new(&relay_config(3478, "1.2.3.4:3478", 100), "shared_secret");
        assert_eq!(server.auth_key, derive_auth_key("shared_secret"));
    }

    #[test]
    fn test_relay_server_shutdown_flag() {
        let server = RelayServer::new(&relay_config(3478, "1.2.3.4:3478", 100), "test");
        assert!(!server.shutdown.load(Ordering::Relaxed));
        server.shutdown();
        assert!(server.shutdown.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn test_relay_server_allocate_roundtrip() {
        let credential = "test_secret";
        let listen_port = free_udp_port("127.0.0.1:0").await;
        let address = format!("127.0.0.1:{listen_port}");
        let server = RelayServer::new(&relay_config(listen_port, &address, 10), credential);
        server.start().await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let mut client =
            crate::nat::turn::RelayClient::new(&client_config(&address, credential)).unwrap();
        let result = client.allocate().await;
        assert!(result.is_ok(), "Allocation failed: {result:?}");
        assert!(client.is_active());
        let _ = client.deallocate().await;
        server.shutdown();
    }

    #[test]
    fn test_relay_server_new_ipv6_external() {
        let server = RelayServer::new(&relay_config(3478, "[::1]:3478", 50), "test_credential");
        assert_eq!(server.config.listen_port, 3478);
        assert_eq!(server.config.external_addr, "[::1]:3478");
    }

    #[tokio::test]
    async fn test_relay_server_start_ipv6() {
        let listen_port = free_udp_port("[::1]:0").await;
        let address = format!("[::1]:{listen_port}");
        let server = RelayServer::new(&relay_config(listen_port, &address, 10), "ipv6_secret");
        let result = server.start().await;
        assert!(result.is_ok(), "IPv6 relay server should start: {result:?}");
        server.shutdown();
    }

    #[tokio::test]
    async fn test_relay_server_allocate_roundtrip_ipv6() {
        let credential = "ipv6_test_secret";
        let listen_port = free_udp_port("[::1]:0").await;
        let address = format!("[::1]:{listen_port}");
        let server = RelayServer::new(&relay_config(listen_port, &address, 10), credential);
        server.start().await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let mut client =
            crate::nat::turn::RelayClient::new(&client_config(&address, credential)).unwrap();
        let result = client.allocate().await;
        assert!(result.is_ok(), "IPv6 allocation failed: {result:?}");
        assert!(client.is_active());
        let relay_addr = result.unwrap();
        assert!(relay_addr.is_ipv6(), "Relay address should be IPv6");
        let _ = client.deallocate().await;
        server.shutdown();
    }
}