use alloc::vec::Vec;
use core::net::{IpAddr, SocketAddr};
use core::ops::Range;
use core::time::Duration;
use stun_proto::agent::StunAgentBuilder;
use stun_proto::auth::Feature;
use turn_types::prelude::DelayedTransmitBuild;
use turn_types::stun::message::IntegrityAlgorithm;
pub use turn_types::transmit::TransmitBuild;
use turn_types::transmit::{DelayedChannel, DelayedMessage};
pub use stun_proto::agent::Transmit;
pub use stun_proto::types::data::Data;
use stun_proto::types::TransportType;
use stun_proto::Instant;
use turn_types::{AddressFamily, TurnCredentials};
/// Interface implemented by TURN client state machines.
///
/// Time is supplied by the caller: most `&mut self` methods take a `now: Instant`
/// argument. Received data is handed in through [`TurnClientApi::recv`], and
/// outgoing data, received peer data, events and wake-up times are retrieved through
/// [`TurnClientApi::poll_transmit`], [`TurnClientApi::poll_recv`],
/// [`TurnClientApi::poll_event`] and [`TurnClientApi::poll`].
/// Implementations must be `Debug + Send`.
pub trait TurnClientApi: core::fmt::Debug + Send {
/// The transport type reported by the client.
///
/// NOTE(review): presumably the transport used to reach the TURN server — confirm
/// against implementations.
fn transport(&self) -> TransportType;
/// The client's local socket address.
///
/// NOTE(review): presumably the local side of the connection to the TURN server —
/// confirm against implementations.
fn local_addr(&self) -> SocketAddr;
/// The client's remote socket address.
///
/// NOTE(review): presumably the TURN server's address — confirm against
/// implementations.
fn remote_addr(&self) -> SocketAddr;
/// Iterates over `(TransportType, SocketAddr)` pairs.
///
/// Presumably the relayed addresses of created allocations; the pair type matches
/// the payload of [`TurnEvent::AllocationCreated`]. TODO confirm.
fn relayed_addresses(&self) -> impl Iterator<Item = (TransportType, SocketAddr)> + '_;
/// Iterates over peer IP addresses associated with the relayed address `relayed`
/// on `transport`.
///
/// Presumably the permissions currently installed on that allocation — confirm
/// against implementations.
fn permissions(
&self,
transport: TransportType,
relayed: SocketAddr,
) -> impl Iterator<Item = IpAddr> + '_;
/// Requests deletion of the client's allocation(s).
///
/// # Errors
/// [`DeleteError::NoAllocation`] when there is no connection to the TURN server.
fn delete(&mut self, now: Instant) -> Result<(), DeleteError>;
/// Requests a permission for `peer_addr` on `transport`.
///
/// The outcome is presumably reported later through
/// [`TurnEvent::PermissionCreated`] / [`TurnEvent::PermissionCreateFailed`] — TODO
/// confirm.
///
/// # Errors
/// See [`CreatePermissionError`] (`AlreadyExists`, `NoAllocation`).
fn create_permission(
&mut self,
transport: TransportType,
peer_addr: IpAddr,
now: Instant,
) -> Result<(), CreatePermissionError>;
/// Returns whether a permission exists for `to` on `transport`.
fn have_permission(&self, transport: TransportType, to: IpAddr) -> bool;
/// Requests a channel binding for `peer_addr` on `transport`.
///
/// The outcome is presumably reported through [`TurnEvent::ChannelCreated`] /
/// [`TurnEvent::ChannelCreateFailed`] — TODO confirm.
///
/// # Errors
/// See [`BindChannelError`] (`AlreadyExists`, `ExpiredChannelExists`,
/// `NoAllocation`).
fn bind_channel(
&mut self,
transport: TransportType,
peer_addr: SocketAddr,
now: Instant,
) -> Result<(), BindChannelError>;
/// Requests a TCP connection to `peer_addr` through the TURN server.
///
/// The outcome is presumably reported through [`TurnEvent::TcpConnected`] /
/// [`TurnEvent::TcpConnectFailed`] — TODO confirm.
///
/// # Errors
/// See [`TcpConnectError`] (`AlreadyExists`, `NoAllocation`, `NoPermission`).
fn tcp_connect(&mut self, peer_addr: SocketAddr, now: Instant) -> Result<(), TcpConnectError>;
/// Informs the client about a TCP socket identified by `id` with the given
/// five-tuple and peer.
///
/// The arguments mirror the fields of [`TurnPollRet::AllocateTcpSocket`]. This is
/// presumably the caller's answer to that request — TODO confirm. `local_addr` is
/// optional; its `None` meaning is not visible here.
///
/// # Errors
/// See [`TcpAllocateError`] (`AlreadyExists`, `NoAllocation`).
fn allocated_tcp_socket(
&mut self,
id: u32,
five_tuple: Socket5Tuple,
peer_addr: SocketAddr,
local_addr: Option<SocketAddr>,
now: Instant,
) -> Result<(), TcpAllocateError>;
/// Informs the client that the TCP connection `local_addr` <-> `remote_addr` was
/// closed.
///
/// The arguments mirror [`TurnPollRet::TcpClose`]; presumably the caller's
/// acknowledgement of it — TODO confirm.
fn tcp_closed(&mut self, local_addr: SocketAddr, remote_addr: SocketAddr, now: Instant);
/// Prepares `data` to be sent to peer `to` over `transport`.
///
/// On success it returns a lazily-built transmit wrapping
/// [`DelayedMessageOrChannelSend`], or `None`. NOTE(review): the meaning of `Ok(None)`
/// is not visible here — confirm with implementations.
///
/// # Errors
/// See [`SendError`] (`NoAllocation`, `NoPermission`, `NoTcpSocket`).
fn send_to<T: AsRef<[u8]> + core::fmt::Debug>(
&mut self,
transport: TransportType,
to: SocketAddr,
data: T,
now: Instant,
) -> Result<Option<TransmitBuild<DelayedMessageOrChannelSend<T>>>, SendError>;
/// Feeds received data into the client; see [`TurnRecvRet`] for the possible
/// outcomes.
fn recv<T: AsRef<[u8]> + core::fmt::Debug>(
&mut self,
transmit: Transmit<T>,
now: Instant,
) -> TurnRecvRet<T>;
/// Returns the next pending owned peer payload, if any.
fn poll_recv(&mut self, now: Instant) -> Option<TurnPeerData<Vec<u8>>>;
/// Advances the client and reports what the caller should do next; see
/// [`TurnPollRet`].
fn poll(&mut self, now: Instant) -> TurnPollRet;
/// Returns the next outgoing transmit, if any.
fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<Data<'static>>>;
/// Returns the next pending [`TurnEvent`], if any.
fn poll_event(&mut self) -> Option<TurnEvent>;
/// Takes no arguments and returns nothing.
///
/// NOTE(review): judging by the name, it notifies the client that a protocol error
/// occurred on its connection — confirm the semantics with implementations.
fn protocol_error(&mut self);
}
/// Configuration for a TURN client.
///
/// Create one with [`TurnConfig::new`] and adjust it with the `set_*` / `add_*`
/// methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TurnConfig {
/// Allocation transport; `TurnConfig::new` sets UDP.
allocation_transport: TransportType,
/// Address families in insertion order; `TurnConfig::new` sets IPv4 only.
/// `add_address_family` skips duplicates.
address_families: smallvec::SmallVec<[AddressFamily; 2]>,
/// Credentials passed to `TurnConfig::new`.
credentials: TurnCredentials,
/// Supported integrity algorithms in insertion order; `TurnConfig::new` sets SHA-1
/// only. `add_supported_integrity` skips duplicates.
supported_integrity: smallvec::SmallVec<[IntegrityAlgorithm; 2]>,
/// Anonymous-username feature setting; `TurnConfig::new` sets `Feature::Auto`.
anonymous_username: Feature,
/// Optional request-retransmission override. `None` leaves the STUN agent builder
/// unchanged in `apply_to_stun_builder`.
rto: Option<RequestRto>,
}
/// Request retransmission parameters that override the STUN agent builder's
/// settings.
///
/// Stored by `TurnConfig::set_request_retransmits`. Each field is forwarded, in
/// declaration order, as the matching argument of
/// `StunAgentBuilder::request_retransmits` by `TurnConfig::apply_to_stun_builder`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct RequestRto {
    /// Forwarded as the `initial` argument.
    initial: Duration,
    /// Forwarded as the `max` argument.
    max: Duration,
    /// Forwarded as the `retransmits` argument.
    retransmits: u32,
    /// Forwarded as the `final_retransmit_timeout` argument.
    final_retransmit_timeout: Duration,
}
impl TurnConfig {
    /// Creates a configuration using `credentials` and these defaults:
    ///
    /// - allocation transport: [`TransportType::Udp`];
    /// - address families: only [`AddressFamily::IPV4`];
    /// - supported integrity algorithms: only [`IntegrityAlgorithm::Sha1`];
    /// - anonymous username: [`Feature::Auto`];
    /// - request retransmissions: no override, so
    ///   [`TurnConfig::apply_to_stun_builder`] leaves the builder untouched.
    pub fn new(credentials: TurnCredentials) -> Self {
        Self {
            allocation_transport: TransportType::Udp,
            address_families: smallvec::smallvec![AddressFamily::IPV4],
            credentials,
            supported_integrity: smallvec::smallvec![IntegrityAlgorithm::Sha1],
            anonymous_username: Feature::Auto,
            rto: None,
        }
    }

    /// Sets the allocation transport (default [`TransportType::Udp`]).
    pub fn set_allocation_transport(&mut self, allocation_transport: TransportType) {
        self.allocation_transport = allocation_transport;
    }

    /// The allocation transport.
    pub fn allocation_transport(&self) -> TransportType {
        self.allocation_transport
    }

    /// Appends `family` to the address families unless it is already present.
    pub fn add_address_family(&mut self, family: AddressFamily) {
        if !self.address_families.contains(&family) {
            self.address_families.push(family);
        }
    }

    /// Replaces all address families with just `family`.
    pub fn set_address_family(&mut self, family: AddressFamily) {
        self.address_families = smallvec::smallvec![family];
    }

    /// The configured address families, in insertion order.
    pub fn address_families(&self) -> &[AddressFamily] {
        &self.address_families
    }

    /// The credentials given to [`TurnConfig::new`].
    pub fn credentials(&self) -> &TurnCredentials {
        &self.credentials
    }

    /// Appends `integrity` to the supported integrity algorithms unless it is already
    /// present.
    pub fn add_supported_integrity(&mut self, integrity: IntegrityAlgorithm) {
        if !self.supported_integrity.contains(&integrity) {
            self.supported_integrity.push(integrity);
        }
    }

    /// Replaces all supported integrity algorithms with just `integrity`.
    pub fn set_supported_integrity(&mut self, integrity: IntegrityAlgorithm) {
        self.supported_integrity = smallvec::smallvec![integrity];
    }

    /// The supported integrity algorithms, in insertion order.
    pub fn supported_integrity(&self) -> &[IntegrityAlgorithm] {
        &self.supported_integrity
    }

    /// Sets the anonymous-username [`Feature`] setting (default [`Feature::Auto`]).
    pub fn set_anonymous_username(&mut self, anon: Feature) {
        self.anonymous_username = anon;
    }

    /// The anonymous-username [`Feature`] setting.
    pub fn anonymous_username(&self) -> Feature {
        self.anonymous_username
    }

    /// Overrides the request retransmission parameters that
    /// [`TurnConfig::apply_to_stun_builder`] passes to
    /// `StunAgentBuilder::request_retransmits`.
    ///
    /// Calling this again replaces all four values.
    pub fn set_request_retransmits(
        &mut self,
        initial: Duration,
        max: Duration,
        retransmits: u32,
        final_retransmit_timeout: Duration,
    ) {
        // Every field is overwritten, so there is nothing to preserve from a
        // previous override: assign the new value directly.
        self.rto = Some(RequestRto {
            initial,
            max,
            retransmits,
            final_retransmit_timeout,
        });
    }

    /// Applies the retransmission override, if one was set, to `builder`.
    /// Otherwise it returns `builder` unchanged.
    pub(crate) fn apply_to_stun_builder(&self, builder: StunAgentBuilder) -> StunAgentBuilder {
        match &self.rto {
            Some(rto) => builder.request_retransmits(
                rto.initial,
                rto.max,
                rto.retransmits,
                rto.final_retransmit_timeout,
            ),
            None => builder,
        }
    }
}
/// Result of [`TurnClientApi::poll`]: what the caller should do next.
#[derive(Debug)]
pub enum TurnPollRet {
/// Nothing to do right now. NOTE(review): presumably the caller should poll again at
/// or after this instant — confirm against implementations.
WaitUntil(Instant),
/// The client asks the caller for a TCP socket.
///
/// The fields mirror the arguments of [`TurnClientApi::allocated_tcp_socket`], which
/// is presumably how the caller replies — TODO confirm.
AllocateTcpSocket {
/// Identifier of the request.
id: u32,
/// Five-tuple of the requested socket.
socket: Socket5Tuple,
/// Peer the socket is for.
peer_addr: SocketAddr,
},
/// The client asks the caller to close the TCP connection between these addresses.
/// The fields mirror the arguments of [`TurnClientApi::tcp_closed`].
TcpClose {
/// Local end of the connection.
local_addr: SocketAddr,
/// Remote end of the connection.
remote_addr: SocketAddr,
},
/// The client is closed.
Closed,
}
/// A socket five-tuple: the transport plus two socket addresses (IP and port each).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Socket5Tuple {
/// Transport protocol.
pub transport: TransportType,
/// Source address.
pub from: SocketAddr,
/// Destination address.
pub to: SocketAddr,
}
/// Result of [`TurnClientApi::recv`].
#[derive(Debug)]
pub enum TurnRecvRet<T: AsRef<[u8]> + core::fmt::Debug> {
/// The input was consumed by the client and there is nothing to return.
Handled,
/// The input was not used by the client; the original transmit is handed back.
Ignored(Transmit<T>),
/// The input carried data from a peer.
PeerData(TurnPeerData<T>),
/// The input carried an ICMP report about a peer.
PeerIcmp {
/// Transport the report refers to.
transport: TransportType,
/// Peer the report refers to.
peer: SocketAddr,
/// ICMP type value.
icmp_type: u8,
/// ICMP code value.
icmp_code: u8,
/// Additional ICMP data as a 32-bit value.
icmp_data: u32,
},
}
/// Data received from a peer, with the transport and peer address it came from.
///
/// The payload is read through [`TurnPeerData::data`] or `AsRef<[u8]>`. It may borrow
/// a range of the received buffer or own a copy (see [`DataRangeOrOwned`]).
#[derive(Debug)]
pub struct TurnPeerData<T: AsRef<[u8]> + core::fmt::Debug> {
/// Payload storage: a range of a buffer, or an owned vector.
pub(crate) data: DataRangeOrOwned<T>,
/// Transport the data arrived on.
pub transport: TransportType,
/// Peer address the data came from.
pub peer: SocketAddr,
}
impl<T: AsRef<[u8]> + core::fmt::Debug> TurnPeerData<T> {
    /// Converts into a `TurnPeerData<R>` that owns its payload.
    ///
    /// A borrowed range is copied into a new `Vec<u8>`; an already-owned payload is
    /// moved without copying. `transport` and `peer` carry over unchanged.
    pub fn into_owned<R: AsRef<[u8]> + core::fmt::Debug>(self) -> TurnPeerData<R> {
        let Self {
            data,
            transport,
            peer,
        } = self;
        TurnPeerData {
            data: data.into_owned(),
            transport,
            peer,
        }
    }

    /// Borrows the payload bytes.
    pub fn data(&self) -> &[u8] {
        AsRef::<[u8]>::as_ref(&self.data)
    }
}
impl<T: AsRef<[u8]> + core::fmt::Debug> AsRef<[u8]> for TurnPeerData<T> {
    /// Same bytes as the inherent `TurnPeerData::data` accessor.
    fn as_ref(&self) -> &[u8] {
        self.data()
    }
}
/// Events reported by [`TurnClientApi::poll_event`].
#[derive(Debug)]
pub enum TurnEvent {
/// An allocation was created, with its transport and address (presumably the relayed
/// address — TODO confirm).
AllocationCreated(TransportType, SocketAddr),
/// Creating an allocation for this address family failed.
AllocationCreateFailed(AddressFamily),
/// A permission for this peer IP was created on this transport.
PermissionCreated(TransportType, IpAddr),
/// Creating a permission for this peer IP on this transport failed.
PermissionCreateFailed(TransportType, IpAddr),
/// A channel to this peer was created on this transport.
ChannelCreated(TransportType, SocketAddr),
/// Creating a channel to this peer on this transport failed.
ChannelCreateFailed(TransportType, SocketAddr),
/// A TCP connection to this peer was established.
TcpConnected(SocketAddr),
/// A TCP connection to this peer could not be established.
TcpConnectFailed(SocketAddr),
}
/// Error type of [`TurnClientApi::bind_channel`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum BindChannelError {
/// The channel identifier already exists.
#[error("The channel identifier already exists and cannot be recreated.")]
AlreadyExists,
/// The channel for this peer has expired and cannot be recreated before the
/// contained instant.
#[error("The channel for requested peer address has expired and cannot be recreated until {}.", .0)]
ExpiredChannelExists(Instant),
/// No connection to the TURN server can handle this channel.
#[error("There is no connection to the TURN server that can handle this channel.")]
NoAllocation,
}
/// Error type of [`TurnClientApi::create_permission`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum CreatePermissionError {
/// The permission already exists.
#[error("The permission already exists and cannot be recreated.")]
AlreadyExists,
/// No connection to the TURN server can handle this permission.
#[error("There is no connection to the TURN server that can handle this permission")]
NoAllocation,
}
/// Error type of [`TurnClientApi::delete`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum DeleteError {
/// There is no connection to the TURN server.
#[error("There is no connection to the TURN server")]
NoAllocation,
}
/// Error type of [`TurnClientApi::send_to`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum SendError {
/// There is no connection to the TURN server.
#[error("There is no connection to the TURN server")]
NoAllocation,
/// No permission is installed for the destination peer.
#[error("There is no permission installed for the requested peer")]
NoPermission,
/// No local TCP socket exists for the destination peer.
#[error("There is no local TCP socket for the requested peer")]
NoTcpSocket,
}
/// Error type of [`TurnClientApi::tcp_connect`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TcpConnectError {
/// The TCP connection already exists.
#[error("The TCP connection already exists and cannot be recreated.")]
AlreadyExists,
/// No connection to the TURN server can handle this TCP socket.
#[error("There is no connection to the TURN server that can handle this TCP socket.")]
NoAllocation,
/// No permission is installed for the peer.
#[error("There is no permission installed for the requested peer")]
NoPermission,
}
/// Error type of [`TurnClientApi::allocated_tcp_socket`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TcpAllocateError {
/// The TCP connection already exists.
#[error("The TCP connection already exists and cannot be recreated.")]
AlreadyExists,
/// No connection to the TURN server can handle this TCP socket.
#[error("There is no connection to the TURN server that can handle this TCP socket.")]
NoAllocation,
}
/// Payload that either borrows a sub-range of a buffer or owns its bytes.
#[derive(Debug)]
pub enum DataRangeOrOwned<T: AsRef<[u8]> + core::fmt::Debug> {
    /// The payload is `data.as_ref()[range]`.
    Range {
        /// Buffer that contains the payload.
        data: T,
        /// Byte range of the payload inside `data`.
        range: Range<usize>,
    },
    /// The payload is the whole vector.
    Owned(Vec<u8>),
}

impl<T: AsRef<[u8]> + core::fmt::Debug> AsRef<[u8]> for DataRangeOrOwned<T> {
    /// Borrows the payload bytes.
    ///
    /// # Panics
    /// Panics if a `Range` variant's range does not lie inside its buffer.
    fn as_ref(&self) -> &[u8] {
        match self {
            Self::Owned(bytes) => bytes.as_slice(),
            Self::Range { data, range } => {
                let whole: &[u8] = data.as_ref();
                &whole[range.clone()]
            }
        }
    }
}

impl<T: AsRef<[u8]> + core::fmt::Debug> DataRangeOrOwned<T> {
    /// Converts into an `Owned` payload with an arbitrary buffer type parameter `R`.
    ///
    /// A `Range` payload is copied into a fresh `Vec<u8>`. An `Owned` payload is moved
    /// across without copying.
    pub(crate) fn into_owned<R: AsRef<[u8]> + core::fmt::Debug>(self) -> DataRangeOrOwned<R> {
        let bytes = match self {
            Self::Owned(bytes) => bytes,
            Self::Range { data, range } => data.as_ref()[range].to_vec(),
        };
        DataRangeOrOwned::Owned(bytes)
    }
}
/// A pending transmission whose payload is the sub-range `range` of the buffer
/// `data`.
#[derive(Debug)]
pub struct DelayedTransmit<T: AsRef<[u8]> + core::fmt::Debug> {
    data: T,
    range: Range<usize>,
}

impl<T: AsRef<[u8]> + core::fmt::Debug> DelayedTransmit<T> {
    /// Borrows the payload, `data.as_ref()[range]`.
    ///
    /// # Panics
    /// Panics if `range` does not lie inside `data`.
    fn data(&self) -> &[u8] {
        let whole: &[u8] = self.data.as_ref();
        &whole[self.range.start..self.range.end]
    }
}
impl<T: AsRef<[u8]> + core::fmt::Debug> DelayedTransmitBuild for DelayedTransmit<T> {
    /// Payload length: the length of `range`.
    fn len(&self) -> usize {
        self.range.len()
    }

    /// Copies the payload into a new `Vec<u8>`.
    fn build(self) -> Vec<u8> {
        self.data().to_vec()
    }

    /// Writes the payload into the start of `data` and returns the number of bytes
    /// written (`self.len()`).
    ///
    /// `data` may be larger than the payload. Only its first `len()` bytes are
    /// written; the rest is left untouched.
    ///
    /// # Panics
    /// Panics if `data` is shorter than `len()`.
    fn write_into(self, data: &mut [u8]) -> usize {
        let len = self.len();
        // `copy_from_slice` requires equal lengths. Target only the prefix so a
        // caller-supplied buffer larger than the payload does not panic.
        data[..len].copy_from_slice(self.data());
        len
    }
}
/// Outgoing data for [`TurnClientApi::send_to`], built lazily via
/// `DelayedTransmitBuild`.
#[derive(Debug)]
pub enum DelayedMessageOrChannelSend<T: AsRef<[u8]> + core::fmt::Debug> {
/// Data wrapped in a channel-data framing (`DelayedChannel`).
Channel(DelayedChannel<T>),
/// Data wrapped in a message addressed to a peer (`DelayedMessage`).
Message(DelayedMessage<T>),
/// Raw bytes sent as-is.
Data(T),
/// Raw owned bytes sent as-is.
OwnedData(Vec<u8>),
}
impl<T: AsRef<[u8]> + core::fmt::Debug> DelayedMessageOrChannelSend<T> {
    /// Wraps `data` for sending over channel `channel_id`.
    pub(crate) fn new_channel(data: T, channel_id: u16) -> Self {
        let channel = DelayedChannel::new(channel_id, data);
        Self::Channel(channel)
    }

    /// Wraps `data` in a message addressed to `peer_addr`, built with
    /// `DelayedMessage::for_server`.
    pub(crate) fn new_message(data: T, peer_addr: SocketAddr) -> Self {
        let message = DelayedMessage::for_server(peer_addr, data);
        Self::Message(message)
    }
}
impl<T: AsRef<[u8]> + core::fmt::Debug> DelayedTransmitBuild for DelayedMessageOrChannelSend<T> {
    /// Number of bytes the built transmission occupies.
    fn len(&self) -> usize {
        match self {
            Self::Channel(channel) => channel.len(),
            Self::Message(msg) => msg.len(),
            Self::Data(data) => data.as_ref().len(),
            Self::OwnedData(owned) => owned.len(),
        }
    }

    /// Serializes the transmission into a `Vec<u8>`. Owned data is returned without
    /// copying.
    fn build(self) -> Vec<u8> {
        match self {
            Self::Channel(channel) => channel.build(),
            Self::Message(msg) => msg.build(),
            Self::Data(data) => data.as_ref().to_vec(),
            Self::OwnedData(owned) => owned,
        }
    }

    /// Writes the transmission into the start of `data` and returns the number of
    /// bytes written.
    ///
    /// For the raw `Data` / `OwnedData` variants, `data` may be larger than the
    /// payload. Only the first `len()` bytes are written. The channel and message
    /// variants delegate to `DelayedChannel::write_into` / `DelayedMessage::write_into`.
    ///
    /// # Panics
    /// For the raw variants, panics if `data` is shorter than the payload.
    fn write_into(self, data: &mut [u8]) -> usize {
        match self {
            Self::Channel(channel) => channel.write_into(data),
            Self::Message(msg) => msg.write_into(data),
            Self::Data(slice) => {
                let bytes = slice.as_ref();
                // `copy_from_slice` requires equal lengths. Fill only the prefix so a
                // larger caller buffer is accepted instead of panicking.
                data[..bytes.len()].copy_from_slice(bytes);
                bytes.len()
            }
            Self::OwnedData(owned) => {
                data[..owned.len()].copy_from_slice(&owned);
                owned.len()
            }
        }
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use alloc::vec;

    use super::*;
    use turn_types::stun::message::Message;
    use turn_types::{
        attribute::{Data as AData, XorPeerAddress},
        channel::ChannelData,
    };

    /// Fixed `(local, remote)` address pair shared by the tests.
    pub(crate) fn generate_addresses() -> (SocketAddr, SocketAddr) {
        (
            "192.168.0.1:1000".parse().unwrap(),
            "10.0.0.2:2000".parse().unwrap(),
        )
    }

    /// A `Message` send must encode the peer address and payload, via both `build()`
    /// and `write_into()`.
    #[test]
    fn test_delayed_message() {
        let (local_addr, remote_addr) = generate_addresses();
        let data = [5; 5];
        let peer_addr = "127.0.0.1:1".parse().unwrap();
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Message(DelayedMessage::for_server(peer_addr, data)),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        assert!(!transmit.data.is_empty());
        let len = transmit.data.len();
        let out = transmit.build();
        assert_eq!(len, out.data.len());
        let msg = Message::from_bytes(&out.data).unwrap();
        let addr = msg.attribute::<XorPeerAddress>().unwrap();
        assert_eq!(addr.addr(msg.transaction_id()), peer_addr);
        let out_data = msg.attribute::<AData>().unwrap();
        assert_eq!(out_data.data(), data.as_ref());
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Message(DelayedMessage::for_server(peer_addr, data)),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        let mut out2 = vec![0; len];
        transmit.write_into(&mut out2);
        let msg = Message::from_bytes(&out2).unwrap();
        let addr = msg.attribute::<XorPeerAddress>().unwrap();
        assert_eq!(addr.addr(msg.transaction_id()), peer_addr);
        let out_data = msg.attribute::<AData>().unwrap();
        assert_eq!(out_data.data(), data.as_ref());
    }

    /// A `Channel` send must encode the channel id and payload, via both `build()` and
    /// `write_into()`.
    #[test]
    fn test_delayed_channel() {
        let (local_addr, remote_addr) = generate_addresses();
        let data = [5; 5];
        let channel_id = 0x4567;
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Channel(DelayedChannel::new(channel_id, data)),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        assert!(!transmit.data.is_empty());
        let len = transmit.data.len();
        let out = transmit.build();
        assert_eq!(len, out.data.len());
        let channel = ChannelData::parse(&out.data).unwrap();
        assert_eq!(channel.id(), channel_id);
        assert_eq!(channel.data(), data.as_ref());
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::Channel(DelayedChannel::new(channel_id, data)),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        let mut out2 = vec![0; len];
        transmit.write_into(&mut out2);
        assert_eq!(len, out2.len());
        let channel = ChannelData::parse(&out2).unwrap();
        assert_eq!(channel.id(), channel_id);
        assert_eq!(channel.data(), data.as_ref());
    }

    /// The raw `Data` variant (generic `T = Vec<u8>`) must pass bytes through
    /// unchanged.
    #[test]
    fn test_delayed_owned() {
        let (local_addr, remote_addr) = generate_addresses();
        let data = vec![7; 7];
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::<Vec<u8>>::Data(data.clone()),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        assert!(!transmit.data.is_empty());
        let len = transmit.data.len();
        let out = transmit.build();
        assert_eq!(len, out.data.len());
        assert_eq!(data, out.data);
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::<Vec<u8>>::Data(data.clone()),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        let mut out2 = vec![0; len];
        transmit.write_into(&mut out2);
        assert_eq!(len, out2.len());
        assert_eq!(data, out2);
    }

    /// The `OwnedData` variant was previously never exercised. It must also pass bytes
    /// through unchanged via both `build()` and `write_into()`.
    #[test]
    fn test_delayed_owned_data() {
        let (local_addr, remote_addr) = generate_addresses();
        let data = vec![9; 9];
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::<Vec<u8>>::OwnedData(data.clone()),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        assert!(!transmit.data.is_empty());
        let len = transmit.data.len();
        assert_eq!(len, data.len());
        let out = transmit.build();
        assert_eq!(len, out.data.len());
        assert_eq!(data, out.data);
        let transmit = TransmitBuild::new(
            DelayedMessageOrChannelSend::<Vec<u8>>::OwnedData(data.clone()),
            TransportType::Udp,
            local_addr,
            remote_addr,
        );
        let mut out2 = vec![0; len];
        transmit.write_into(&mut out2);
        assert_eq!(data, out2);
    }
}