use alloc::string::String;
use alloc::vec::Vec;
use core::{net::SocketAddr, time::Duration};
pub use stun_proto::agent::Transmit;
use turn_types::prelude::DelayedTransmitBuild;
use turn_types::stun::{attribute::ErrorCode, TransportType};
use turn_types::transmit::{DelayedChannel, DelayedMessage, TransmitBuild};
use turn_types::AddressFamily;
use turn_types::Instant;
/// Interface implemented by a TURN server.
///
/// Time is always supplied by the caller through `now` parameters. Outgoing data is handed back
/// to the caller as return values (from [`recv`](TurnServerApi::recv),
/// [`recv_icmp`](TurnServerApi::recv_icmp) and [`poll_transmit`](TurnServerApi::poll_transmit)).
/// Socket work the server needs done is requested through [`poll`](TurnServerApi::poll)
/// ([`TurnServerPollRet`]). Its outcome is reported back via
/// [`allocated_socket`](TurnServerApi::allocated_socket) and
/// [`tcp_connected`](TurnServerApi::tcp_connected).
pub trait TurnServerApi: Send + core::fmt::Debug {
/// Register a user with the given `username` and `password`.
fn add_user(&mut self, username: String, password: String);
/// The address the server listens on.
fn listen_address(&self) -> SocketAddr;
/// Set how long a nonce remains valid before it expires.
fn set_nonce_expiry_duration(&mut self, expiry_duration: Duration);
/// Process an incoming `transmit` at time `now`.
///
/// Returns the data to send in response, if any. The payload is a lazily built
/// [`DelayedMessageOrChannelSend`] that uses the same `T` as the input, so it may refer back to
/// the received data (e.g. [`DelayedMessageOrChannelSend::Range`]) instead of copying it.
fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
&mut self,
transmit: Transmit<T>,
now: Instant,
) -> Option<TransmitBuild<DelayedMessageOrChannelSend<T>>>;
/// Process an ICMP packet (`bytes`) for the given address `family` at time `now`.
///
/// Returns data to send as a result, if any.
fn recv_icmp<T: AsRef<[u8]>>(
&mut self,
family: AddressFamily,
bytes: T,
now: Instant,
) -> Option<Transmit<Vec<u8>>>;
/// Return the next action the caller should take; see [`TurnServerPollRet`].
fn poll(&mut self, now: Instant) -> TurnServerPollRet;
/// Return the next pending outgoing transmission, if any.
fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Vec<u8>>>;
/// Report the outcome of a socket allocation requested by
/// [`TurnServerPollRet::AllocateSocket`].
///
/// The first five parameters mirror that variant's fields. `socket_addr` carries either the
/// allocated address or the reason the allocation failed.
// The argument list mirrors `TurnServerPollRet::AllocateSocket` one-to-one, which exceeds
// clippy's default argument-count limit.
#[allow(clippy::too_many_arguments)]
fn allocated_socket(
&mut self,
transport: TransportType,
listen_addr: SocketAddr,
client_addr: SocketAddr,
allocation_transport: TransportType,
family: AddressFamily,
socket_addr: Result<SocketAddr, SocketAllocateError>,
now: Instant,
);
/// Report the outcome of a TCP connection requested by [`TurnServerPollRet::TcpConnect`].
///
/// The four address parameters mirror that variant's fields. `socket_addr` carries either the
/// resulting address or the reason the connection failed.
fn tcp_connected(
&mut self,
relayed_addr: SocketAddr,
peer_addr: SocketAddr,
listen_addr: SocketAddr,
client_addr: SocketAddr,
socket_addr: Result<SocketAddr, TcpConnectError>,
now: Instant,
);
}
/// The result of [`TurnServerApi::poll`]: what the caller should do next.
#[derive(Debug)]
pub enum TurnServerPollRet {
/// Nothing further is requested before the given instant.
WaitUntil(Instant),
/// The server requests a new socket.
///
/// The outcome is expected back through [`TurnServerApi::allocated_socket`], whose leading
/// parameters mirror these fields.
AllocateSocket {
transport: TransportType,
listen_addr: SocketAddr,
client_addr: SocketAddr,
allocation_transport: TransportType,
family: AddressFamily,
},
/// The server requests a TCP connection from `relayed_addr` to `peer_addr`.
///
/// The outcome is expected back through [`TurnServerApi::tcp_connected`], whose leading
/// parameters mirror these fields.
TcpConnect {
relayed_addr: SocketAddr,
peer_addr: SocketAddr,
listen_addr: SocketAddr,
client_addr: SocketAddr,
},
/// The server requests that the TCP connection between these two addresses be closed.
TcpClose {
local_addr: SocketAddr,
remote_addr: SocketAddr,
},
/// The server requests that the socket identified by `transport`/`listen_addr` be closed.
SocketClose {
transport: TransportType,
listen_addr: SocketAddr,
},
}
/// Reasons a socket allocation can fail.
///
/// Reported to the server through [`TurnServerApi::allocated_socket`]. Each variant maps to a
/// STUN error code via `SocketAllocateError::into_error_code`.
#[derive(Debug, Clone, Copy, thiserror::Error, PartialEq, Eq)]
pub enum SocketAllocateError {
/// The requested address family is not supported.
#[error("The address family is not supported.")]
AddressFamilyNotSupported,
/// The server lacks the capacity to allocate another socket.
#[error("The server does not have the capacity to handle this request.")]
InsufficientCapacity,
}
impl SocketAllocateError {
pub fn into_error_code(self) -> u16 {
match self {
Self::AddressFamilyNotSupported => ErrorCode::ADDRESS_FAMILY_NOT_SUPPORTED,
Self::InsufficientCapacity => ErrorCode::INSUFFICIENT_CAPACITY,
}
}
}
#[derive(Debug, Clone, Copy, thiserror::Error, PartialEq, Eq)]
pub enum TcpConnectError {
#[error("The server does not have the capacity to handle this request.")]
InsufficientCapacity,
#[error("Connection is forbidden by local policy.")]
Forbidden,
#[error("Timed out attempting to connect to the specifid peer.")]
TimedOut,
#[error("Failed for any other unspecified reason.")]
Failure,
}
impl TcpConnectError {
pub fn into_error_code(self) -> u16 {
match self {
Self::InsufficientCapacity => ErrorCode::INSUFFICIENT_CAPACITY,
Self::Forbidden => ErrorCode::FORBIDDEN,
Self::TimedOut | Self::Failure => ErrorCode::CONNECTION_TIMEOUT_OR_FAILURE,
}
}
}
/// An outgoing payload whose bytes are produced only when it is built or written.
///
/// Implements `DelayedTransmitBuild`, so it can be turned into a `Vec<u8>` or written into a
/// caller-provided buffer.
#[derive(Debug)]
pub enum DelayedMessageOrChannelSend<T: AsRef<[u8]> + core::fmt::Debug> {
/// A STUN message built from a [`DelayedMessage`].
Message(DelayedMessage<T>),
/// Channel data built from a [`DelayedChannel`].
Channel(DelayedChannel<T>),
/// Bytes that are already fully formed; emitted as-is.
Owned(Vec<u8>),
/// The sub-slice `data[range]` of the given buffer.
// NOTE(review): `len` computes `range.end - range.start`, so this assumes
// `range.start <= range.end` and that `range` lies within the buffer.
Range(T, core::ops::Range<usize>),
}
impl<T: AsRef<[u8]> + core::fmt::Debug> DelayedTransmitBuild for DelayedMessageOrChannelSend<T> {
    /// Length in bytes of the data this payload produces.
    fn len(&self) -> usize {
        match self {
            Self::Message(msg) => msg.len(),
            Self::Channel(channel) => channel.len(),
            Self::Owned(v) => v.len(),
            // Assumes `range.start <= range.end`; a reversed range is a caller bug.
            Self::Range(_data, range) => range.end - range.start,
        }
    }

    /// Produces the payload as a `Vec<u8>`.
    ///
    /// `Owned` is returned without copying; `Range` copies `data[range]`.
    ///
    /// # Panics
    ///
    /// For the `Range` variant, panics if the range lies outside the source buffer.
    fn build(self) -> Vec<u8> {
        match self {
            Self::Message(msg) => msg.build(),
            Self::Channel(channel) => channel.build(),
            Self::Owned(v) => v,
            Self::Range(data, range) => data.as_ref()[range].to_vec(),
        }
    }

    /// Whether this payload produces no bytes.
    fn is_empty(&self) -> bool {
        match self {
            Self::Message(msg) => msg.is_empty(),
            Self::Channel(channel) => channel.is_empty(),
            Self::Owned(v) => v.is_empty(),
            Self::Range(_data, range) => range.end == range.start,
        }
    }

    /// Writes the payload into `data` and returns the number of bytes written.
    ///
    /// The `Range` variant fills only the first `len()` bytes of `data`, so `data` may be larger
    /// than the payload. The other variants delegate to their own `write_into`.
    ///
    /// # Panics
    ///
    /// For the `Range` variant, panics if `data` is shorter than the range, or if the range lies
    /// outside the source buffer.
    fn write_into(self, data: &mut [u8]) -> usize {
        match self {
            Self::Message(msg) => msg.write_into(data),
            Self::Channel(channel) => channel.write_into(data),
            Self::Owned(v) => v.write_into(data),
            Self::Range(src, range) => {
                let bytes = &src.as_ref()[range];
                // `copy_from_slice` panics when the lengths differ, so copy into the matching
                // prefix only. Copying into the whole output buffer panicked whenever the
                // caller passed a buffer larger than the range.
                data[..bytes.len()].copy_from_slice(bytes);
                bytes.len()
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec;
    use turn_types::attribute::Data as AData;
    use turn_types::attribute::XorPeerAddress;
    use turn_types::channel::ChannelData;
    use turn_types::stun::message::Message;

    use super::*;

    /// The fixed `(local, remote)` address pair shared by every test in this module.
    fn generate_addresses() -> (SocketAddr, SocketAddr) {
        let local: SocketAddr = "192.168.0.1:1000".parse().unwrap();
        let remote: SocketAddr = "10.0.0.2:2000".parse().unwrap();
        (local, remote)
    }

    /// Wraps `payload` in a UDP [`TransmitBuild`] between the addresses from
    /// [`generate_addresses`].
    fn udp_transmit<T: AsRef<[u8]> + core::fmt::Debug>(
        payload: DelayedMessageOrChannelSend<T>,
    ) -> TransmitBuild<DelayedMessageOrChannelSend<T>> {
        let (local_addr, remote_addr) = generate_addresses();
        TransmitBuild::new(payload, TransportType::Udp, local_addr, remote_addr)
    }

    #[test]
    fn test_delayed_message() {
        let payload: [u8; 5] = [5; 5];
        let peer_addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
        // `build` and `write_into` both consume the transmit, so each path gets a fresh payload.
        let make = || {
            DelayedMessageOrChannelSend::Message(DelayedMessage::for_server(peer_addr, payload))
        };
        // Parses serialized bytes as a STUN message and checks both carried attributes.
        let verify = |bytes: &[u8]| {
            let msg = Message::from_bytes(bytes).unwrap();
            let xor_peer = msg.attribute::<XorPeerAddress>().unwrap();
            assert_eq!(xor_peer.addr(msg.transaction_id()), peer_addr);
            let carried = msg.attribute::<AData>().unwrap();
            assert_eq!(carried.data(), payload.as_ref());
        };

        let first = udp_transmit(make());
        assert!(!first.data.is_empty());
        let len = first.data.len();
        let built = first.build();
        assert_eq!(len, built.data.len());
        verify(&built.data[..]);

        let mut written = vec![0; len];
        udp_transmit(make()).write_into(&mut written);
        verify(&written[..]);
    }

    #[test]
    fn test_delayed_channel() {
        let payload: [u8; 5] = [5; 5];
        let channel_id = 0x4567;
        let make = || DelayedMessageOrChannelSend::Channel(DelayedChannel::new(channel_id, payload));
        // Parses serialized bytes as channel data and checks id and payload.
        let verify = |bytes: &[u8]| {
            let channel = ChannelData::parse(bytes).unwrap();
            assert_eq!(channel.id(), channel_id);
            assert_eq!(channel.data(), payload.as_ref());
        };

        let first = udp_transmit(make());
        assert!(!first.data.is_empty());
        let len = first.data.len();
        let built = first.build();
        assert_eq!(len, built.data.len());
        verify(&built.data[..]);

        let mut written = vec![0; len];
        udp_transmit(make()).write_into(&mut written);
        assert_eq!(len, written.len());
        verify(&written[..]);
    }

    #[test]
    fn test_delayed_owned() {
        let expected: Vec<u8> = vec![7; 7];
        let make = || DelayedMessageOrChannelSend::<Vec<u8>>::Owned(expected.clone());

        let first = udp_transmit(make());
        assert!(!first.data.is_empty());
        let len = first.data.len();
        let built = first.build();
        assert_eq!(len, built.data.len());
        assert_eq!(expected, built.data);

        let mut written = vec![0; len];
        udp_transmit(make()).write_into(&mut written);
        assert_eq!(len, written.len());
        assert_eq!(expected, written);
    }

    #[test]
    fn test_delayed_range() {
        const LEN: usize = 4;
        let source: Vec<u8> = vec![7; 7];
        let range: core::ops::Range<usize> = 2..6;
        let make = || DelayedMessageOrChannelSend::Range(source.clone(), range.clone());

        let first = udp_transmit(make());
        let len = first.data.len();
        assert_eq!(len, LEN);
        let built = first.build();
        assert_eq!(len, built.data.len());
        assert_eq!(source[range.clone()], built.data);

        let mut written = vec![0; len];
        udp_transmit(make()).write_into(&mut written);
        assert_eq!(len, written.len());
        assert_eq!(source[range.clone()], written);
    }
}