use crate::{
ctx::ClientContext,
servers::{spawn_server_task, GAME_HOST_PORT, RANDOM_PORT, TUNNEL_HOST_PORT},
};
use log::{debug, error};
use pocket_relay_udp_tunnel::{
deserialize_message, serialize_message, MessageError, TunnelMessage,
};
use std::{
future::Future,
io::ErrorKind,
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
time::Duration,
};
use thiserror::Error;
use tokio::{
io::ReadBuf,
net::UdpSocket,
sync::mpsc,
time::{interval_at, sleep, timeout, Instant, Interval, MissedTickBehavior},
try_join,
};
/// Number of local sockets held by a tunnel (length of the pool array).
const SOCKET_POOL_SIZE: usize = 4;
/// Number of consecutive tunnel creation failures tolerated before
/// [`start_udp_tunnel_server`] gives up and returns an error.
const MAX_ERROR_ATTEMPTS: u8 = 5;
/// Loopback address on `GAME_HOST_PORT` that every pool socket is connected to.
static LOCAL_SEND_TARGET: SocketAddr =
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, GAME_HOST_PORT));
/// Errors that can occur while creating or running the UDP tunnel.
#[derive(Debug, Error)]
pub enum UdpTunnelError {
/// NOTE(review): not constructed anywhere in the code visible here;
/// presumably reserved for a base URL without a usable host — confirm.
#[error("host url incompatible with UDP tunnel")]
HostIncompatible,
/// NOTE(review): not constructed anywhere in the code visible here;
/// presumably reserved for servers without tunnel support — confirm.
#[error("server incompatible with UDP tunnel")]
ServerIncompatible,
/// Binding the local socket used to reach the tunnel server failed.
#[error(transparent)]
Bind(std::io::Error),
/// Connecting the tunnel socket to the remote host failed.
#[error(transparent)]
Connect(std::io::Error),
/// A handshake attempt did not complete within the handshake timeout.
#[error("timeout reached while handshaking")]
HandshakeTimeout,
/// Any other I/O failure (converted automatically through `?`).
#[error(transparent)]
GenericIo(#[from] std::io::Error),
/// A received packet could not be deserialized.
#[error("malformed packet: {0}")]
MalformedPacket(#[from] MessageError),
/// The handshake response was something other than an `Initiated` message.
#[error("unexpected packet while handshaking")]
UnexpectedPacket,
/// Allocating the pool of local sockets failed.
#[error(transparent)]
AllocateSocketPool(std::io::Error),
}
/// Runs the UDP tunnel client, recreating the tunnel whenever it stops.
///
/// Returns `Ok(())` immediately when the context's base URL has no host or
/// when the context holds no association token, since no tunnel can be
/// created in either case.
///
/// Otherwise the tunnel is created in a loop:
/// * when a tunnel was created and later stopped, the error counter is reset
///   and a new tunnel is attempted after one second;
/// * when creation fails, the wait before the next attempt grows by one
///   second per consecutive failure.
///
/// # Errors
///
/// After `MAX_ERROR_ATTEMPTS` consecutive failures the most recent
/// [`UdpTunnelError`] is returned wrapped in a [`std::io::Error`]. The error
/// is returned right away, without a trailing back-off sleep.
pub async fn start_udp_tunnel_server(
    ctx: Arc<ClientContext>,
    tunnel_port: u16,
) -> std::io::Result<()> {
    let Some(host) = ctx.base_url.host().map(|value| value.to_string()) else {
        return Ok(());
    };
    let Some(association) = Option::as_ref(&ctx.association) else {
        return Ok(());
    };

    let mut attempt_errors: u8 = 0;
    loop {
        let reconnect_time = match create_tunnel(&host, tunnel_port, association).await {
            // The tunnel ran and then stopped: treat it as healthy and reconnect quickly
            Ok(()) => {
                attempt_errors = 0;
                Duration::from_millis(1000)
            }
            Err(err) => {
                error!("Failed to create tunnel: {}", err);
                attempt_errors += 1;
                // Give up immediately on the final failure rather than
                // sleeping for an attempt that will never happen
                if attempt_errors >= MAX_ERROR_ATTEMPTS {
                    return Err(std::io::Error::new(ErrorKind::Other, err));
                }
                Duration::from_millis(1000 * u64::from(attempt_errors))
            }
        };
        debug!(
            "Next tunnel create attempt in: {}s",
            reconnect_time.as_secs()
        );
        tokio::time::sleep(reconnect_time).await;
    }
}
async fn create_tunnel(
host: &str,
tunnel_port: u16,
association: &str,
) -> Result<(), UdpTunnelError> {
let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0))
.await
.map_err(UdpTunnelError::Bind)?;
socket
.connect((host, tunnel_port))
.await
.map_err(UdpTunnelError::Connect)?;
debug!("initiating tunnel: {}:{}", host, tunnel_port);
let tunnel_id = attempt_tunnel_handshake(&socket, association).await?;
debug!("created server tunnel: {}", tunnel_id);
let (tx, rx) = mpsc::unbounded_channel();
let pool = Socket::allocate_pool(tx)
.await
.map_err(UdpTunnelError::AllocateSocketPool)?;
debug!("Allocated tunnel pool");
let now = Instant::now();
let keep_alive_start = now + KEEP_ALIVE_DELAY;
let mut keep_alive_interval = interval_at(keep_alive_start, KEEP_ALIVE_DELAY);
keep_alive_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
Tunnel {
socket,
tunnel_id,
rx,
pool,
write_state: Default::default(),
read_buffer: [0u8; u16::MAX as usize],
last_keep_alive: now,
keep_alive_interval,
}
.await;
Ok(())
}
/// Number of retries made after the initial handshake attempt fails.
const MAX_HANDSHAKE_ATTEMPTS: u8 = 5;
/// Maximum time a single handshake attempt may take.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(5);

/// Performs the tunnel handshake, retrying with an exponential back-off.
///
/// Each attempt is limited to `HANDSHAKE_TIMEOUT`. After a failed attempt the
/// function waits (5 seconds at first, doubling every time) before trying
/// again, up to `MAX_HANDSHAKE_ATTEMPTS` retries after the initial attempt.
///
/// Returns the tunnel ID assigned by the server.
///
/// # Errors
///
/// Returns the error of the final attempt once the retries are exhausted;
/// this is [`UdpTunnelError::HandshakeTimeout`] when that attempt timed out.
/// The error is returned immediately, without waiting out another back-off
/// delay.
async fn attempt_tunnel_handshake(
    socket: &UdpSocket,
    association: &str,
) -> Result<u32, UdpTunnelError> {
    let mut retry_count: u8 = 0;
    let mut retry_delay: u64 = 5;

    loop {
        let last_err = match timeout(HANDSHAKE_TIMEOUT, handshake_tunnel(socket, association)).await
        {
            Ok(Ok(value)) => return Ok(value),
            Ok(Err(err)) => {
                error!("failed to handshake for token: {}", err);
                err
            }
            Err(_) => {
                error!("timeout while attempting tunnel handshake");
                UdpTunnelError::HandshakeTimeout
            }
        };

        retry_count += 1;

        // Check the limit BEFORE sleeping so the final failure is reported
        // right away instead of after a pointless back-off delay
        if retry_count > MAX_HANDSHAKE_ATTEMPTS {
            return Err(last_err);
        }

        sleep(Duration::from_secs(retry_delay)).await;
        retry_delay *= 2;
    }
}
/// Performs a single handshake exchange with the tunnel server.
///
/// Sends an `Initiate` message carrying the association token (using
/// `u32::MAX` as the tunnel ID, since none has been assigned yet) and waits
/// for one response packet.
///
/// # Errors
///
/// * I/O failures while sending or receiving,
/// * [`UdpTunnelError::MalformedPacket`] when the response cannot be decoded,
/// * [`UdpTunnelError::UnexpectedPacket`] when the response is not `Initiated`.
async fn handshake_tunnel(socket: &UdpSocket, association: &str) -> Result<u32, UdpTunnelError> {
    let initiate = TunnelMessage::Initiate {
        association_token: association.to_owned(),
    };
    let request = serialize_message(u32::MAX, &initiate);
    socket.send(&request).await?;

    let mut response = [0u8; u16::MAX as usize];
    let length = socket.recv(&mut response).await?;
    let packet = deserialize_message(&response[..length])?;

    if let TunnelMessage::Initiated { tunnel_id } = packet.message {
        Ok(tunnel_id)
    } else {
        Err(UdpTunnelError::UnexpectedPacket)
    }
}
/// Period of the timer that triggers the tunnel's keep-alive check; also the
/// delay before the first check after the tunnel is created.
const KEEP_ALIVE_DELAY: Duration = Duration::from_secs(10);
/// Threshold the keep-alive check compares against when deciding whether to
/// stop the tunnel (four times `KEEP_ALIVE_DELAY`).
const KEEP_ALIVE_TIMEOUT: Duration = Duration::from_secs(KEEP_ALIVE_DELAY.as_secs() * 4);
/// Future driving the connection between the tunnel server and the pool of
/// local sockets.
struct Tunnel {
/// UDP socket connected to the tunnel server
socket: UdpSocket,
/// ID obtained from the handshake, used when serializing outgoing messages
tunnel_id: u32,
/// Receiver for messages that should be written to the tunnel server
rx: mpsc::UnboundedReceiver<TunnelMessage>,
/// Handles to the local sockets, indexed by the `index` of `Forward` messages
pool: [SocketHandle; SOCKET_POOL_SIZE],
/// Current state of the writing half
write_state: TunnelWriteState,
/// Buffer that packets from the tunnel server are read into
read_buffer: [u8; u16::MAX as usize],
/// Instant the last `KeepAlive` message was received (creation time initially)
last_keep_alive: Instant,
/// Timer triggering the periodic keep-alive check
keep_alive_interval: Interval,
}
/// State of the writing half of a [`Tunnel`].
#[derive(Default)]
enum TunnelWriteState {
/// Waiting for the next outgoing message from the channel
#[default]
Recv,
/// Holding a message that still has to be written to the tunnel socket
Write(Option<TunnelMessage>),
/// The tunnel must stop; it must not be polled again in this state
Stop,
}
/// Outcome of one step of the reading half of a [`Tunnel`].
enum TunnelReadState {
/// Keep reading from the tunnel socket
Continue,
/// The tunnel must stop
Stop,
}
impl Tunnel {
fn poll_write_state(&mut self, cx: &mut Context<'_>) -> Poll<TunnelWriteState> {
Poll::Ready(match &mut self.write_state {
TunnelWriteState::Recv => {
let result = ready!(Pin::new(&mut self.rx).poll_recv(cx));
if let Some(message) = result {
TunnelWriteState::Write(Some(message))
} else {
TunnelWriteState::Stop
}
}
TunnelWriteState::Write(message) => {
if ready!(Pin::new(&mut self.socket).poll_send_ready(cx)).is_ok() {
let message = message
.take()
.expect("Unexpected write state without message");
let buffer = serialize_message(self.tunnel_id, &message);
ready!(Pin::new(&mut self.socket).poll_send(cx, &buffer))
.expect("Message encoder errored");
TunnelWriteState::Recv
} else {
TunnelWriteState::Stop
}
}
TunnelWriteState::Stop => panic!("Tunnel polled after already stopped"),
})
}
fn poll_read_state(&mut self, cx: &mut Context<'_>) -> Poll<TunnelReadState> {
if self.keep_alive_interval.poll_tick(cx).is_ready() {
debug!("checking connection alive");
let now = Instant::now();
let last_alive = self.last_keep_alive.duration_since(now);
if last_alive > KEEP_ALIVE_TIMEOUT {
return Poll::Ready(TunnelReadState::Stop);
}
}
if ready!(Pin::new(&mut self.socket).poll_recv_ready(cx)).is_err() {
return Poll::Ready(TunnelReadState::Stop);
};
let mut read_buffer = ReadBuf::new(&mut self.read_buffer);
if ready!(Pin::new(&mut self.socket).poll_recv(cx, &mut read_buffer)).is_err() {
return Poll::Ready(TunnelReadState::Stop);
};
let buffer = read_buffer.filled();
let packet = match deserialize_message(buffer) {
Ok(value) => value,
Err(err) => {
error!("encountered invalid tunnel message: {}", err);
return Poll::Ready(TunnelReadState::Stop);
}
};
match packet.message {
TunnelMessage::Forward { index, message } => {
let handle = self.pool.get(index as usize);
if let Some(handle) = handle {
_ = handle.0.send(message);
}
}
TunnelMessage::KeepAlive => {
self.last_keep_alive = Instant::now();
self.write_state = TunnelWriteState::Write(Some(TunnelMessage::KeepAlive));
if let Poll::Ready(next_state) = self.poll_write_state(cx) {
self.write_state = next_state;
if let TunnelWriteState::Stop = self.write_state {
return Poll::Ready(TunnelReadState::Stop);
}
}
}
_ => {
}
}
Poll::Ready(TunnelReadState::Continue)
}
}
impl Future for Tunnel {
    type Output = ();

    /// Drives the writing half until it can make no more progress, then the
    /// reading half. Completes as soon as either half reports `Stop`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let tunnel = self.get_mut();

        // Writing half: store every new state, finish when it is `Stop`
        loop {
            let Poll::Ready(state) = tunnel.poll_write_state(cx) else {
                break;
            };
            let stopped = matches!(state, TunnelWriteState::Stop);
            tunnel.write_state = state;
            if stopped {
                return Poll::Ready(());
            }
        }

        // Reading half: keep reading until pending or stopped
        loop {
            match tunnel.poll_read_state(cx) {
                Poll::Ready(TunnelReadState::Continue) => continue,
                Poll::Ready(TunnelReadState::Stop) => return Poll::Ready(()),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
/// Handle to a spawned [`Socket`] task; wraps the sender for payloads that
/// the socket should write.
#[derive(Clone)]
struct SocketHandle(mpsc::UnboundedSender<Vec<u8>>);
/// Size in bytes of the buffer a [`Socket`] reads into (2^16).
const READ_BUFFER_LENGTH: usize = 2usize.pow(16);
/// Future driving one local socket of the tunnel pool.
struct Socket {
/// Index of this socket within the pool, attached to forwarded messages
index: u8,
/// Local UDP socket, connected to `LOCAL_SEND_TARGET`
socket: UdpSocket,
/// Receiver for payloads that should be written to the local socket
rx: mpsc::UnboundedReceiver<Vec<u8>>,
/// Sender for messages that should go through the tunnel
tun_tx: mpsc::UnboundedSender<TunnelMessage>,
/// Buffer that data from the local socket is read into
read_buffer: [u8; READ_BUFFER_LENGTH],
/// Current state of the writing half
write_state: SocketWriteState,
}
/// State of the writing half of a [`Socket`].
#[derive(Default)]
enum SocketWriteState {
/// Waiting for the next payload from the channel
#[default]
Recv,
/// Holding the bytes that still have to be written to the local socket
Write(Vec<u8>),
/// The socket must stop; it must not be polled again in this state
Stop,
}
/// Outcome of one step of the reading half of a [`Socket`].
enum SocketReadState {
/// Keep reading from the local socket
Continue,
/// The socket must stop
Stop,
}
impl Socket {
async fn allocate_pool(
tun_tx: mpsc::UnboundedSender<TunnelMessage>,
) -> std::io::Result<[SocketHandle; SOCKET_POOL_SIZE]> {
let sockets = try_join!(
Socket::start(0, TUNNEL_HOST_PORT, tun_tx.clone()),
Socket::start(1, RANDOM_PORT, tun_tx.clone()),
Socket::start(2, RANDOM_PORT, tun_tx.clone()),
Socket::start(3, RANDOM_PORT, tun_tx),
)?;
Ok(sockets.into())
}
async fn start(
index: u8,
port: u16,
tun_tx: mpsc::UnboundedSender<TunnelMessage>,
) -> std::io::Result<SocketHandle> {
let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, port)).await?;
socket.connect(LOCAL_SEND_TARGET).await?;
let (tx, rx) = mpsc::unbounded_channel();
spawn_server_task(Socket {
index,
socket,
rx,
tun_tx,
read_buffer: [0; READ_BUFFER_LENGTH],
write_state: Default::default(),
});
Ok(SocketHandle(tx))
}
fn poll_write_state(&mut self, cx: &mut Context<'_>) -> Poll<SocketWriteState> {
Poll::Ready(match &mut self.write_state {
SocketWriteState::Recv => {
let result = ready!(Pin::new(&mut self.rx).poll_recv(cx));
if let Some(message) = result {
SocketWriteState::Write(message)
} else {
SocketWriteState::Stop
}
}
SocketWriteState::Write(message) => {
let Ok(count) = ready!(self.socket.poll_send(cx, message)) else {
return Poll::Ready(SocketWriteState::Stop);
};
if count != message.len() {
let remaining = message.split_off(count);
SocketWriteState::Write(remaining)
} else {
SocketWriteState::Recv
}
}
SocketWriteState::Stop => panic!("Tunnel socket polled after already stopped"),
})
}
fn poll_read_state(&mut self, cx: &mut Context<'_>) -> Poll<SocketReadState> {
let mut read_buf = ReadBuf::new(&mut self.read_buffer);
if ready!(self.socket.poll_recv(cx, &mut read_buf)).is_err() {
return Poll::Ready(SocketReadState::Stop);
}
let bytes = read_buf.filled();
let message = TunnelMessage::Forward {
index: self.index,
message: bytes.to_vec(),
};
Poll::Ready(if self.tun_tx.send(message).is_ok() {
SocketReadState::Continue
} else {
SocketReadState::Stop
})
}
}
impl Future for Socket {
    type Output = ();

    /// Drives the writing half until it can make no more progress, then the
    /// reading half. Completes as soon as either half reports `Stop`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let socket = self.get_mut();

        // Writing half: store every new state, finish when it is `Stop`
        loop {
            let Poll::Ready(state) = socket.poll_write_state(cx) else {
                break;
            };
            let stopped = matches!(state, SocketWriteState::Stop);
            socket.write_state = state;
            if stopped {
                return Poll::Ready(());
            }
        }

        // Reading half: keep reading until pending or stopped
        loop {
            match socket.poll_read_state(cx) {
                Poll::Ready(SocketReadState::Continue) => continue,
                Poll::Ready(SocketReadState::Stop) => return Poll::Ready(()),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}