use std::{collections::HashMap, net::SocketAddr, sync::LazyLock, time::Duration};
use anyhow::{Result, anyhow, ensure};
use base64::{Engine, prelude::BASE64_STANDARD};
use bytes::BytesMut;
use rand::seq::SliceRandom;
use tokio::{
net::UdpSocket,
time::{sleep, timeout},
};
use turn_server::{
config::{Api, Auth, Config, Interface, Log, Server},
prelude::*,
start_server,
};
static TOKEN: LazyLock<[u8; 12]> = LazyLock::new(|| {
let mut rng = rand::rng();
let mut token = [0u8; 12];
token.shuffle(&mut rng);
token
});
pub async fn spawn_turn_server(bind: SocketAddr, auth: Auth, api: Api) -> Result<()> {
tokio::spawn(async move {
start_server(Config {
log: Log::default(),
server: Server {
realm: "localhost".to_string(),
interfaces: vec![Interface::Udp {
external: bind,
listen: bind,
idle_timeout: 30,
mtu: 1500,
}],
..Default::default()
},
auth,
api: Some(api),
..Default::default()
})
.await
.unwrap();
});
sleep(Duration::from_secs(3)).await;
Ok(())
}
struct TurnTransport {
decoder: Decoder,
socket: UdpSocket,
recv_bytes: [u8; 1500],
send_bytes: BytesMut,
}
impl TurnTransport {
async fn new(server: SocketAddr) -> Result<Self> {
let socket = UdpSocket::bind("127.0.0.1:0").await?;
socket.connect(server).await?;
Ok(Self {
send_bytes: BytesMut::with_capacity(1500),
decoder: Decoder::default(),
recv_bytes: [0u8; 1500],
socket,
})
}
fn local_addr(&self) -> Result<SocketAddr> {
Ok(self.socket.local_addr()?)
}
fn prepare_message(&mut self, method: Method) -> MessageEncoder<'_> {
MessageEncoder::new(method, &TOKEN, &mut self.send_bytes)
}
fn queue_channel_frame(&mut self, number: u16, bytes: &[u8]) {
ChannelData::new(number, bytes).encode(&mut self.send_bytes);
}
async fn send(&self) -> Result<()> {
self.socket.send(&self.send_bytes).await?;
Ok(())
}
async fn receive_message(&mut self) -> Result<Message<'_>> {
let size = timeout(
Duration::from_secs(1),
self.socket.recv(&mut self.recv_bytes),
)
.await??;
if let DecodeResult::Message(message) = self.decoder.decode(&self.recv_bytes[..size])? {
if message.transaction_id() != TOKEN.as_slice() {
Err(anyhow!("Message token does not match"))
} else {
Ok(message)
}
} else {
Err(anyhow!("payload not a message"))
}
}
async fn receive_channel_data(&mut self) -> Result<ChannelData<'_>> {
let size = timeout(
Duration::from_secs(1),
self.socket.recv(&mut self.recv_bytes),
)
.await??;
if let DecodeResult::ChannelData(channel_data) =
self.decoder.decode(&self.recv_bytes[..size])?
{
Ok(channel_data)
} else {
Err(anyhow!("payload not a channel data"))
}
}
}
/// Long-term credentials a test client presents to the TURN server.
pub struct Credentials {
    /// Sent as the USERNAME attribute on authenticated requests.
    pub username: String,
    /// Combined with the username and the server's realm to derive the
    /// message-integrity key.
    pub password: String,
}
/// A ChannelData frame received by a test client.
///
/// Derives `Debug`/`PartialEq` so whole frames can be compared with `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChannelFrame<'a> {
    /// Channel number from the frame header.
    pub number: u16,
    /// Frame payload, borrowed from the client's receive buffer.
    pub payload: &'a [u8],
}
/// A Data indication received by a test client.
///
/// Derives `Debug`/`PartialEq` so whole indications can be compared with `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Indication<'a> {
    /// Port from the XOR-PEER-ADDRESS attribute, i.e. the sending peer's relay port.
    pub peer_port: u16,
    /// Contents of the DATA attribute.
    pub payload: &'a [u8],
}
/// Authentication state learned from the server's 401 challenge.
struct SessionState {
    /// Key used to sign authenticated requests and verify their responses.
    integrity_key: Password,
    /// NONCE echoed back on authenticated requests.
    nonce: String,
    /// REALM echoed back on authenticated requests.
    realm: String,
}

impl Default for SessionState {
    /// Placeholder state: empty nonce/realm and an all-zero MD5 key, meant to
    /// be replaced once the server's challenge has been received.
    fn default() -> Self {
        Self {
            nonce: String::new(),
            realm: String::new(),
            integrity_key: Password::Md5([0; 16]),
        }
    }
}
pub struct TurnClient {
transport: TurnTransport,
credentials: Credentials,
server: SocketAddr,
state: SessionState,
}
impl TurnClient {
pub async fn new(server: SocketAddr, credentials: Credentials) -> Result<Self> {
Ok(Self {
transport: TurnTransport::new(server).await?,
state: SessionState::default(),
credentials,
server,
})
}
pub async fn perform_binding(&mut self) -> Result<()> {
{
let mut message = self.transport.prepare_message(BINDING_REQUEST);
message.flush(None)?;
self.transport.send().await?;
}
let local_addr = self.transport.local_addr()?;
let message = self.transport.receive_message().await?;
ensure!(message.method() == BINDING_RESPONSE);
ensure!(message.get::<XorMappedAddress>() == Some(local_addr));
ensure!(message.get::<MappedAddress>() == Some(local_addr));
ensure!(message.get::<ResponseOrigin>() == Some(self.server));
Ok(())
}
pub async fn request_allocation(&mut self) -> Result<u16> {
{
{
let mut message = self.transport.prepare_message(ALLOCATE_REQUEST);
message.append::<RequestedTransport>(RequestedTransport::Udp);
message.flush(None)?;
self.transport.send().await?;
}
let message = self.transport.receive_message().await?;
ensure!(message.method() == ALLOCATE_ERROR);
ensure!(message.get::<ErrorCode>().unwrap().code == ErrorType::Unauthorized as u16);
self.state.nonce = message.get::<Nonce>().unwrap().to_string();
self.state.realm = message.get::<Realm>().unwrap().to_string();
self.state.integrity_key = generate_password(
&self.credentials.username,
&self.credentials.password,
&self.state.realm,
PasswordAlgorithm::Md5,
);
}
{
let mut message = self.transport.prepare_message(ALLOCATE_REQUEST);
message.append::<RequestedTransport>(RequestedTransport::Udp);
message.append::<UserName>(&self.credentials.username);
message.append::<Realm>(&self.state.realm);
message.append::<Nonce>(&self.state.nonce);
message.flush(Some(&self.state.integrity_key))?;
self.transport.send().await?;
}
let local_addr = self.transport.local_addr()?;
let message = self.transport.receive_message().await?;
ensure!(message.method() == ALLOCATE_RESPONSE);
message.verify(&self.state.integrity_key)?;
let relay = message.get::<XorRelayedAddress>().unwrap();
ensure!(relay.ip() == self.server.ip());
ensure!(message.get::<XorMappedAddress>() == Some(local_addr));
ensure!(message.get::<Lifetime>() == Some(600));
Ok(relay.port())
}
pub async fn grant_permission(&mut self, port: u16) -> Result<()> {
{
let mut peer = self.server.clone();
peer.set_port(port);
let mut message = self.transport.prepare_message(CREATE_PERMISSION_REQUEST);
message.append::<XorPeerAddress>(peer);
message.append::<UserName>(&self.credentials.username);
message.append::<Realm>(&self.state.realm);
message.append::<Nonce>(&self.state.nonce);
message.flush(Some(&self.state.integrity_key))?;
self.transport.send().await?;
}
let message = self.transport.receive_message().await?;
ensure!(message.method() == CREATE_PERMISSION_RESPONSE);
message.verify(&self.state.integrity_key)?;
Ok(())
}
pub async fn bind_channel(&mut self, port: u16, channel: u16) -> Result<()> {
{
let mut peer = self.server.clone();
peer.set_port(port);
let mut message = self.transport.prepare_message(CHANNEL_BIND_REQUEST);
message.append::<ChannelNumber>(channel);
message.append::<XorPeerAddress>(peer);
message.append::<UserName>(&self.credentials.username);
message.append::<Realm>(&self.state.realm);
message.append::<Nonce>(&self.state.nonce);
message.flush(Some(&self.state.integrity_key))?;
self.transport.send().await?;
}
let message = self.transport.receive_message().await?;
ensure!(message.method() == CHANNEL_BIND_RESPONSE);
message.verify(&self.state.integrity_key)?;
Ok(())
}
pub async fn refresh_allocation(&mut self, lifetime: u32) -> Result<()> {
{
let mut message = self.transport.prepare_message(REFRESH_REQUEST);
message.append::<Lifetime>(lifetime);
message.append::<UserName>(&self.credentials.username);
message.append::<Realm>(&self.state.realm);
message.append::<Nonce>(&self.state.nonce);
message.flush(Some(&self.state.integrity_key))?;
self.transport.send().await?;
}
let message = self.transport.receive_message().await?;
ensure!(message.method() == REFRESH_RESPONSE);
message.verify(&self.state.integrity_key)?;
ensure!(message.get::<Lifetime>() == Some(lifetime));
Ok(())
}
pub async fn send_indication_to(&mut self, port: u16, data: &[u8]) -> Result<()> {
let mut peer = self.server.clone();
peer.set_port(port);
let mut message = self.transport.prepare_message(SEND_INDICATION);
message.append::<XorPeerAddress>(peer);
message.append::<Data>(data);
message.flush(None)?;
self.transport.send().await?;
Ok(())
}
pub async fn receive_indication(&mut self) -> Result<Indication<'_>> {
let message = self.transport.receive_message().await?;
ensure!(message.method() == DATA_INDICATION);
let peer = message.get::<XorPeerAddress>().unwrap();
let data = message.get::<Data>().unwrap();
Ok(Indication {
peer_port: peer.port(),
payload: data,
})
}
pub async fn send_channel_frame(&mut self, channel: u16, data: &[u8]) -> Result<()> {
self.transport.queue_channel_frame(channel, data);
self.transport.send().await?;
Ok(())
}
pub async fn receive_channel_frame(&mut self) -> Result<ChannelFrame<'_>> {
let frame = self.transport.receive_channel_data().await?;
Ok(ChannelFrame {
number: frame.number(),
payload: frame.bytes(),
})
}
}
/// Derives the password for a shared-secret (`static_auth_secret`) login.
///
/// The result is the Base64 (standard alphabet) encoding of `hmac_sha1` over
/// `username`, presumably keyed with `password` (the shared secret).
/// The body never produces an error; the `Result` return type is kept for callers.
fn encode_password(username: &str, password: &str) -> Result<String> {
    let digest = hmac_sha1(password.as_bytes(), &[username.as_bytes()]);
    let encoded = BASE64_STANDARD.encode(digest.as_slice());
    Ok(encoded)
}
/// End-to-end scenario for the TURN server under test.
///
/// Four clients allocate relays and set up an asymmetric set of permissions and
/// channel bindings. The test then checks which ChannelData frames and
/// Send/Data indications are relayed and which are dropped.
///
/// Permissions and channel bindings installed below (peer -> channel):
/// - turn_1: permits 2, 3, 4; 2 -> 0x4000, 3 -> 0x4001, 4 -> 0x4002
/// - turn_2: permits 1, 3;    1 -> 0x4000, 3 -> 0x4002
/// - turn_3: permits 1, 2;    1 -> 0x4001, 2 -> 0x4002
/// - turn_4: permits 1;       1 -> 0x4002
///
/// Each pair of clients uses the same channel number on both sides, so a frame
/// sent on channel N is expected to arrive at the peer with number N.
///
/// Every "nothing arrives" check waits out the client's one-second receive
/// timeout, so this test is slow by design.
///
/// NOTE(review): the server binds the fixed address 127.0.0.1:3478, so the test
/// presumably fails if something else is already using that port.
#[tokio::test]
async fn integration_testing() -> Result<()> {
    // Enable both credential modes: a static username/password pair and a
    // shared secret (`static_auth_secret`).
    spawn_turn_server(
        "127.0.0.1:3478".parse()?,
        Auth {
            enable_hooks_auth: false,
            static_auth_secret: Some("static_auth_secret".to_string()),
            static_credentials: {
                let mut it = HashMap::with_capacity(1);
                it.insert(
                    "static_credentials".to_string(),
                    "static_credentials".to_string(),
                );
                it
            },
        },
        Api::default(),
    )
    .await?;
    // turn_1 and turn_2 log in with the static username/password pair.
    let mut turn_1 = TurnClient::new(
        "127.0.0.1:3478".parse()?,
        Credentials {
            username: "static_credentials".to_string(),
            password: "static_credentials".to_string(),
        },
    )
    .await?;
    let mut turn_2 = TurnClient::new(
        "127.0.0.1:3478".parse()?,
        Credentials {
            username: "static_credentials".to_string(),
            password: "static_credentials".to_string(),
        },
    )
    .await?;
    // turn_3 and turn_4 log in with a password derived from the shared secret.
    let mut turn_3 = TurnClient::new(
        "127.0.0.1:3478".parse()?,
        Credentials {
            username: "static_auth_secret".to_string(),
            password: encode_password("static_auth_secret", "static_auth_secret")?,
        },
    )
    .await?;
    let mut turn_4 = TurnClient::new(
        "127.0.0.1:3478".parse()?,
        Credentials {
            username: "static_auth_secret".to_string(),
            password: encode_password("static_auth_secret", "static_auth_secret")?,
        },
    )
    .await?;
    // Plain STUN binding check.
    // NOTE(review): turn_4 is skipped here; confirm whether that is intentional.
    {
        turn_1.perform_binding().await?;
        turn_2.perform_binding().await?;
        turn_3.perform_binding().await?;
    }
    // Allocate one relay per client; peers are addressed below by relay port.
    let turn_1_port = turn_1.request_allocation().await?;
    let turn_2_port = turn_2.request_allocation().await?;
    let turn_3_port = turn_3.request_allocation().await?;
    let turn_4_port = turn_4.request_allocation().await?;
    // Repeating the allocation is expected to return the same relay port.
    assert_eq!(turn_1.request_allocation().await?, turn_1_port);
    assert_eq!(turn_2.request_allocation().await?, turn_2_port);
    assert_eq!(turn_3.request_allocation().await?, turn_3_port);
    assert_eq!(turn_4.request_allocation().await?, turn_4_port);
    {
        // Install the permissions and channel bindings listed in the doc
        // comment, then refresh each allocation with a lifetime of 600.
        turn_1.grant_permission(turn_2_port).await?;
        turn_1.grant_permission(turn_3_port).await?;
        turn_1.grant_permission(turn_4_port).await?;
        turn_1.bind_channel(turn_2_port, 0x4000).await?;
        turn_1.bind_channel(turn_3_port, 0x4001).await?;
        turn_1.bind_channel(turn_4_port, 0x4002).await?;
        turn_1.refresh_allocation(600).await?;
        turn_2.grant_permission(turn_1_port).await?;
        turn_2.grant_permission(turn_3_port).await?;
        turn_2.bind_channel(turn_1_port, 0x4000).await?;
        turn_2.bind_channel(turn_3_port, 0x4002).await?;
        turn_2.refresh_allocation(600).await?;
        turn_3.grant_permission(turn_1_port).await?;
        turn_3.grant_permission(turn_2_port).await?;
        turn_3.bind_channel(turn_1_port, 0x4001).await?;
        turn_3.bind_channel(turn_2_port, 0x4002).await?;
        turn_3.refresh_allocation(600).await?;
        turn_4.grant_permission(turn_1_port).await?;
        turn_4.bind_channel(turn_1_port, 0x4002).await?;
        turn_4.refresh_allocation(600).await?;
        // Re-binding an existing (peer, channel) pair is expected to succeed.
        assert!(turn_1.bind_channel(turn_2_port, 0x4000).await.is_ok());
        assert!(turn_1.bind_channel(turn_3_port, 0x4001).await.is_ok());
        assert!(turn_1.bind_channel(turn_4_port, 0x4002).await.is_ok());
        assert!(turn_2.bind_channel(turn_1_port, 0x4000).await.is_ok());
        assert!(turn_2.bind_channel(turn_3_port, 0x4002).await.is_ok());
        assert!(turn_3.bind_channel(turn_1_port, 0x4001).await.is_ok());
        assert!(turn_3.bind_channel(turn_2_port, 0x4002).await.is_ok());
        assert!(turn_4.bind_channel(turn_1_port, 0x4002).await.is_ok());
    }
    {
        // turn_1 has a channel to every other client, so every frame arrives.
        let data = "1 forwards to 2,3,4 channel data".as_bytes();
        turn_1.send_channel_frame(0x4000, data).await?;
        let ret = turn_2.receive_channel_frame().await?;
        assert_eq!(ret.number, 0x4000);
        assert_eq!(ret.payload, data);
        turn_1.send_channel_frame(0x4001, data).await?;
        let ret = turn_3.receive_channel_frame().await?;
        assert_eq!(ret.number, 0x4001);
        assert_eq!(ret.payload, data);
        turn_1.send_channel_frame(0x4002, data).await?;
        let ret = turn_4.receive_channel_frame().await?;
        assert_eq!(ret.number, 0x4002);
        assert_eq!(ret.payload, data);
    }
    {
        // turn_2 reaches turn_1 (0x4000) and turn_3 (0x4002); every other
        // client must receive nothing.
        let data = "2 forwards to 1,3 channel data".as_bytes();
        turn_2.send_channel_frame(0x4000, data).await?;
        let ret = turn_1.receive_channel_frame().await?;
        assert_eq!(ret.number, 0x4000);
        assert_eq!(ret.payload, data);
        assert!(turn_3.receive_channel_frame().await.is_err());
        assert!(turn_4.receive_channel_frame().await.is_err());
        turn_2.send_channel_frame(0x4002, data).await?;
        let ret = turn_3.receive_channel_frame().await?;
        assert_eq!(ret.number, 0x4002);
        assert_eq!(ret.payload, data);
        assert!(turn_1.receive_channel_frame().await.is_err());
        assert!(turn_4.receive_channel_frame().await.is_err());
        // turn_2 never bound 0x4001, so this frame goes nowhere.
        turn_2.send_channel_frame(0x4001, data).await?;
        assert!(turn_1.receive_channel_frame().await.is_err());
        assert!(turn_3.receive_channel_frame().await.is_err());
        assert!(turn_4.receive_channel_frame().await.is_err());
    }
    {
        // turn_3 reaches turn_1 (0x4001) and turn_2 (0x4002).
        let data = "3 forwards to 1,2 channel data".as_bytes();
        turn_3.send_channel_frame(0x4001, data).await?;
        let ret = turn_1.receive_channel_frame().await?;
        assert_eq!(ret.number, 0x4001);
        assert_eq!(ret.payload, data);
        assert!(turn_2.receive_channel_frame().await.is_err());
        assert!(turn_4.receive_channel_frame().await.is_err());
        turn_3.send_channel_frame(0x4002, data).await?;
        let ret = turn_2.receive_channel_frame().await?;
        assert_eq!(ret.number, 0x4002);
        assert_eq!(ret.payload, data);
        assert!(turn_1.receive_channel_frame().await.is_err());
        assert!(turn_4.receive_channel_frame().await.is_err());
        // turn_3 never bound 0x4000, so this frame goes nowhere.
        turn_3.send_channel_frame(0x4000, data).await?;
        assert!(turn_2.receive_channel_frame().await.is_err());
        assert!(turn_1.receive_channel_frame().await.is_err());
        assert!(turn_4.receive_channel_frame().await.is_err());
    }
    {
        // turn_4 only bound 0x4002 (to turn_1); frames on 0x4000 and 0x4001
        // go nowhere.
        let data = "4 forwards to 1 channel data".as_bytes();
        turn_4.send_channel_frame(0x4002, data).await?;
        let ret = turn_1.receive_channel_frame().await?;
        assert_eq!(ret.number, 0x4002);
        assert_eq!(ret.payload, data);
        assert!(turn_2.receive_channel_frame().await.is_err());
        assert!(turn_3.receive_channel_frame().await.is_err());
        turn_4.send_channel_frame(0x4000, data).await?;
        assert!(turn_1.receive_channel_frame().await.is_err());
        assert!(turn_2.receive_channel_frame().await.is_err());
        assert!(turn_3.receive_channel_frame().await.is_err());
        turn_4.send_channel_frame(0x4001, data).await?;
        assert!(turn_1.receive_channel_frame().await.is_err());
        assert!(turn_2.receive_channel_frame().await.is_err());
        assert!(turn_3.receive_channel_frame().await.is_err());
    }
    {
        // Send/Data indications. turn_1 permits every peer and every peer
        // permits turn_1, so all three arrive with turn_1's relay port as source.
        let data = "1 forwards to 2,3,4".as_bytes();
        turn_1.send_indication_to(turn_2_port, data).await?;
        let ret = turn_2.receive_indication().await?;
        assert_eq!(ret.peer_port, turn_1_port);
        assert_eq!(ret.payload, data);
        turn_1.send_indication_to(turn_3_port, data).await?;
        let ret = turn_3.receive_indication().await?;
        assert_eq!(ret.peer_port, turn_1_port);
        assert_eq!(ret.payload, data);
        turn_1.send_indication_to(turn_4_port, data).await?;
        let ret = turn_4.receive_indication().await?;
        assert_eq!(ret.peer_port, turn_1_port);
        assert_eq!(ret.payload, data);
    }
    {
        // turn_2 and turn_4 have no permission for each other, so the last
        // indication is dropped.
        let data = "2 forwards to 1,3".as_bytes();
        turn_2.send_indication_to(turn_1_port, data).await?;
        let ret = turn_1.receive_indication().await?;
        assert_eq!(ret.peer_port, turn_2_port);
        assert_eq!(ret.payload, data);
        turn_2.send_indication_to(turn_3_port, data).await?;
        let ret = turn_3.receive_indication().await?;
        assert_eq!(ret.peer_port, turn_2_port);
        assert_eq!(ret.payload, data);
        turn_2.send_indication_to(turn_4_port, data).await?;
        assert!(turn_4.receive_indication().await.is_err());
    }
    {
        // Likewise turn_3 and turn_4 have no permission for each other.
        let data = "3 forwards to 1,2".as_bytes();
        turn_3.send_indication_to(turn_1_port, data).await?;
        let ret = turn_1.receive_indication().await?;
        assert_eq!(ret.peer_port, turn_3_port);
        assert_eq!(ret.payload, data);
        turn_3.send_indication_to(turn_2_port, data).await?;
        let ret = turn_2.receive_indication().await?;
        assert_eq!(ret.peer_port, turn_3_port);
        assert_eq!(ret.payload, data);
        turn_3.send_indication_to(turn_4_port, data).await?;
        assert!(turn_4.receive_indication().await.is_err());
    }
    {
        // turn_4 -> turn_3 is dropped for the same reason.
        let data = "4 forwards to 1".as_bytes();
        turn_4.send_indication_to(turn_1_port, data).await?;
        let ret = turn_1.receive_indication().await?;
        assert_eq!(ret.peer_port, turn_4_port);
        assert_eq!(ret.payload, data);
        turn_4.send_indication_to(turn_3_port, data).await?;
        assert!(turn_3.receive_indication().await.is_err());
    }
    {
        // A lifetime of 0 presumably releases the allocations (TURN treats a
        // zero-LIFETIME Refresh as deletion).
        // NOTE(review): turn_4's allocation is not released here; confirm intent.
        turn_1.refresh_allocation(0).await?;
        turn_2.refresh_allocation(0).await?;
        turn_3.refresh_allocation(0).await?;
    }
    Ok(())
}