#![deny(missing_debug_implementations)]
#![deny(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![deny(clippy::std_instead_of_core)]
#![deny(clippy::std_instead_of_alloc)]
#![no_std]
extern crate alloc;
#[cfg(any(feature = "std", test))]
extern crate std;
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::net::SocketAddr;
use core::time::Duration;
use std::io::{Read, Write};
use turn_server_proto::types::prelude::DelayedTransmitBuild;
use turn_server_proto::types::transmit::TransmitBuild;
use turn_server_proto::types::AddressFamily;
use turn_server_proto::api::Transmit;
use turn_server_proto::server::TurnServer;
use turn_server_proto::types::stun::TransportType;
use turn_server_proto::types::Instant;
pub use turn_server_proto as proto;
pub use turn_server_proto::api;
use turn_server_proto::api::{
DelayedMessageOrChannelSend, SocketAllocateError, TurnServerApi, TurnServerPollRet,
};
use tracing::{info, trace, warn};
use rustls::{ServerConfig, ServerConnection};
/// A TURN server that terminates TLS (via `rustls`) on its TCP listening address.
///
/// TLS records received on the listening address are decrypted and the plaintext is handed to
/// an inner [`TurnServer`]. Anything that server sends back over TCP from the listening address
/// to a connected TLS client is encrypted with that client's session before being returned.
/// All other traffic is exchanged with the inner server unchanged.
#[derive(Debug)]
pub struct RustlsTurnServer {
    /// The TURN protocol state machine that processes the decrypted traffic.
    server: TurnServer,
    /// rustls configuration used to create the session for every new client connection.
    config: Arc<ServerConfig>,
    /// Per-connection TLS state, one entry per remote TCP address.
    clients: Vec<Client>,
}
/// TLS state for a single client connected to the listening address.
#[derive(Debug)]
struct Client {
    /// Remote address of the client's TCP connection. Transmits are matched to this
    /// session by comparing against it.
    client_addr: SocketAddr,
    /// The rustls server-side session for this connection.
    tls: ServerConnection,
    /// Set once our own `close_notify` has been queued on `tls`.
    local_closed: bool,
    /// Set once rustls reports that the peer has sent its `close_notify`.
    peer_closed: bool,
}
impl RustlsTurnServer {
    /// Creates a TURN server that expects TLS on its TCP listening address.
    ///
    /// * `listen_addr` - the TCP address clients connect to.
    /// * `realm` - the realm handed to the inner [`TurnServer`].
    /// * `config` - the rustls configuration used for each incoming TLS connection.
    ///
    /// The server starts with no connections. A TLS session is created the first time
    /// non-empty data arrives from a new client address.
    pub fn new(listen_addr: SocketAddr, realm: String, config: Arc<ServerConfig>) -> Self {
        Self {
            server: TurnServer::new(TransportType::Tcp, listen_addr, realm),
            config,
            clients: vec![],
        }
    }
}
impl TurnServerApi for RustlsTurnServer {
fn add_user(&mut self, username: String, password: String) {
self.server.add_user(username, password)
}
fn listen_address(&self) -> SocketAddr {
self.server.listen_address()
}
fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration) {
self.server.set_nonce_expiry_duration(expiry_duration)
}
#[tracing::instrument(
name = "turn_server_rustls_recv",
skip(self, transmit, now),
fields(
from = ?transmit.from,
data_len = transmit.data.as_ref().len()
)
)]
fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
&mut self,
transmit: Transmit<T>,
now: Instant,
) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>> {
let listen_address = self.listen_address();
if transmit.transport == TransportType::Tcp && transmit.to == listen_address {
trace!("receiving TLS data: {:x?}", transmit.data.as_ref());
let client = match self
.clients
.iter_mut()
.find(|client| client.client_addr == transmit.from)
{
Some(client) => client,
None => {
if transmit.data.as_ref().is_empty() {
return None;
}
let len = self.clients.len();
self.clients.push(Client {
client_addr: transmit.from,
tls: ServerConnection::new(self.config.clone()).unwrap(),
local_closed: false,
peer_closed: false,
});
info!("new connection from {}", transmit.from);
&mut self.clients[len]
}
};
let mut input = std::io::Cursor::new(transmit.data.as_ref());
let io_state = match client.tls.read_tls(&mut input) {
Ok(_written) => match client.tls.process_new_packets() {
Ok(io_state) => io_state,
Err(e) => {
warn!("Error processing incoming TLS: {e:?}");
return None;
}
},
Err(e) => {
warn!("Error receiving data: {e:?}");
return None;
}
};
if io_state.peer_has_closed() {
client.peer_closed = true;
if !client.local_closed {
client.tls.send_close_notify();
client.local_closed = true;
let mut out = vec![];
client.tls.write_tls(&mut out).unwrap();
let client_addr = client.client_addr;
info!("client {client_addr} TLS closed");
return Some(TransmitBuild::new(
DelayedMessageOrChannelSend::Owned(out),
TransportType::Tcp,
listen_address,
client_addr,
));
} else {
return None;
}
}
if io_state.plaintext_bytes_to_read() == 0 {
return None;
}
let mut vec = vec![0; 2048];
let n = match client.tls.reader().read(&mut vec) {
Ok(n) => n,
Err(e) => {
if e.kind() == std::io::ErrorKind::WouldBlock {
return None;
} else {
warn!("TLS error: {e:?}");
return None;
}
}
};
trace!("io_state: {io_state:?}, n: {n}");
vec.resize(n, 0);
let transmit = self.server.recv(
Transmit::new(vec, transmit.transport, transmit.from, transmit.to),
now,
)?;
if transmit.transport == TransportType::Tcp
&& transmit.from == listen_address
&& transmit.to == client.client_addr
{
let plaintext = transmit.data.build();
client.tls.writer().write_all(&plaintext).unwrap();
let mut out = vec![];
client.tls.write_tls(&mut out).unwrap();
Some(TransmitBuild::new(
DelayedMessageOrChannelSend::Owned(out),
TransportType::Tcp,
listen_address,
client.client_addr,
))
} else {
let transmit = transmit.build();
Some(TransmitBuild::new(
DelayedMessageOrChannelSend::Owned(transmit.data),
transmit.transport,
transmit.from,
transmit.to,
))
}
} else if let Some(transmit) = self.server.recv(transmit, now) {
if transmit.transport == TransportType::Tcp && transmit.from == listen_address {
let Some(client) = self
.clients
.iter_mut()
.find(|client| transmit.to == client.client_addr)
else {
return Some(transmit);
};
let plaintext = transmit.data.build();
client.tls.writer().write_all(&plaintext).unwrap();
let mut out = vec![];
client.tls.write_tls(&mut out).unwrap();
Some(TransmitBuild::new(
DelayedMessageOrChannelSend::Owned(out),
TransportType::Tcp,
listen_address,
client.client_addr,
))
} else {
Some(transmit)
}
} else {
None
}
}
fn recv_icmp<T: AsRef<[u8]>>(
&mut self,
family: AddressFamily,
bytes: T,
now: Instant,
) -> Option<Transmit<Vec<u8>>> {
let transmit = self.server.recv_icmp(family, bytes, now)?;
let listen_address = self.listen_address();
if transmit.transport == TransportType::Tcp && transmit.from == listen_address {
let Some(client) = self
.clients
.iter_mut()
.find(|client| transmit.to == client.client_addr)
else {
return Some(transmit);
};
client.tls.writer().write_all(&transmit.data).unwrap();
let mut out = vec![];
client.tls.write_tls(&mut out).unwrap();
Some(Transmit::new(
out,
TransportType::Tcp,
listen_address,
client.client_addr,
))
} else {
Some(transmit)
}
}
fn poll(&mut self, now: Instant) -> TurnServerPollRet {
let protocol_ret = self.server.poll(now);
let mut have_pending = false;
for (idx, client) in self.clients.iter_mut().enumerate() {
trace!("client: {client:?}");
let io_state = match client.tls.process_new_packets() {
Ok(io_state) => io_state,
Err(e) => {
warn!("Error processing TLS: {e:?}");
continue;
}
};
trace!("{io_state:?}");
if io_state.tls_bytes_to_write() > 0 {
have_pending = true;
continue;
} else if !client.peer_closed && io_state.peer_has_closed() {
client.peer_closed = true;
if !client.local_closed {
client.tls.send_close_notify();
client.local_closed = true;
have_pending = true;
continue;
}
}
if client.local_closed && client.peer_closed && !client.tls.wants_write() {
let client = self.clients.remove(idx);
return TurnServerPollRet::TcpClose {
local_addr: self.server.listen_address(),
remote_addr: client.client_addr,
};
}
}
if let TurnServerPollRet::TcpClose {
local_addr,
remote_addr,
} = protocol_ret
{
let Some(client) = self
.clients
.iter_mut()
.find(|client| client.client_addr == remote_addr)
else {
return TurnServerPollRet::TcpClose {
local_addr,
remote_addr,
};
};
client.tls.send_close_notify();
client.local_closed = true;
return TurnServerPollRet::WaitUntil(now);
}
if have_pending {
return TurnServerPollRet::WaitUntil(now);
}
protocol_ret
}
fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>> {
let listen_address = self.listen_address();
while let Some(transmit) = self.server.poll_transmit(now) {
if let Some(client) = self
.clients
.iter_mut()
.find(|client| transmit.to == client.client_addr)
{
if transmit.data.is_empty() {
if !client.local_closed {
warn!("client {} closed", client.client_addr);
client.tls.send_close_notify();
client.local_closed = true;
}
} else {
client.tls.writer().write_all(&transmit.data).unwrap();
}
} else {
warn!("return transmit: {transmit:?}");
return Some(transmit);
};
}
for client in self.clients.iter_mut() {
trace!("client: {client:?}");
let client_addr = client.client_addr;
if !client.tls.wants_write() {
continue;
}
let mut vec = vec![];
let n = match client.tls.write_tls(&mut vec) {
Ok(n) => n,
Err(e) => {
warn!("error writing TLS: {e:?}");
continue;
}
};
vec.resize(n, 0);
warn!("return transmit: {vec:x?}");
return Some(Transmit::new(
vec,
TransportType::Tcp,
listen_address,
client_addr,
));
}
None
}
fn allocated_socket(
&mut self,
transport: TransportType,
local_addr: SocketAddr,
remote_addr: SocketAddr,
allocation_transport: TransportType,
family: AddressFamily,
socket_addr: Result<SocketAddr, SocketAllocateError>,
now: Instant,
) {
self.server.allocated_socket(
transport,
local_addr,
remote_addr,
allocation_transport,
family,
socket_addr,
now,
)
}
fn tcp_connected(
&mut self,
relayed_addr: SocketAddr,
peer_addr: SocketAddr,
listen_addr: SocketAddr,
client_addr: SocketAddr,
socket_addr: Result<SocketAddr, crate::api::TcpConnectError>,
now: Instant,
) {
self.server.tcp_connected(
relayed_addr,
peer_addr,
listen_addr,
client_addr,
socket_addr,
now,
)
}
}