use self::codec::{TunnelCodec, TunnelMessage};
use crate::{
api::create_server_tunnel,
ctx::ClientContext,
servers::{spawn_server_task, GAME_HOST_PORT, RANDOM_PORT, TUNNEL_HOST_PORT},
};
use bytes::Bytes;
use futures::{Sink, Stream};
use log::{debug, error};
use reqwest::Upgraded;
use std::{
future::Future,
io::ErrorKind,
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
time::Duration,
};
use tokio::{io::ReadBuf, net::UdpSocket, sync::mpsc, try_join};
use tokio_util::codec::Framed;
/// Number of local UDP sockets bridged over a single tunnel; the `index` of a
/// tunnel message selects which pooled socket it belongs to.
const SOCKET_POOL_SIZE: usize = 4;
/// Number of consecutive tunnel creation failures tolerated before
/// `start_tunnel_server` gives up.
const MAX_ERROR_ATTEMPTS: u8 = 5;
/// Loopback address (on `GAME_HOST_PORT`) that every pooled socket is
/// connected to.
static LOCAL_SEND_TARGET: SocketAddr =
    SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, GAME_HOST_PORT));
/// Repeatedly creates the server tunnel for the client's association,
/// recreating it each time it ends.
///
/// Returns `Ok(())` immediately when `ctx.association` is `None`. After a
/// tunnel ends normally the next one is created after one second; after a
/// failure the delay grows by one second per consecutive failure.
///
/// # Errors
///
/// Returns the most recent creation error once [`MAX_ERROR_ATTEMPTS`]
/// consecutive attempts have failed.
pub async fn start_tunnel_server(ctx: Arc<ClientContext>) -> std::io::Result<()> {
    // Without an association there is nothing to tunnel.
    let Some(association) = ctx.association.as_ref() else {
        return Ok(());
    };

    let mut last_error: Option<std::io::Error> = None;
    let mut attempt_errors: u8 = 0;
    while attempt_errors < MAX_ERROR_ATTEMPTS {
        let reconnect_time = match create_tunnel(ctx.clone(), association).await {
            Ok(()) => {
                attempt_errors = 0;
                Duration::from_millis(1000)
            }
            Err(err) => {
                error!("Failed to create tunnel: {}", err);
                last_error = Some(err);
                attempt_errors += 1;
                if attempt_errors >= MAX_ERROR_ATTEMPTS {
                    // Out of attempts: report the failure now instead of
                    // sleeping before an attempt that will never be made.
                    break;
                }
                // Linear back-off: one extra second per consecutive failure.
                Duration::from_millis(1000 * u64::from(attempt_errors))
            }
        };
        debug!(
            "Next tunnel create attempt in: {}s",
            reconnect_time.as_secs()
        );
        tokio::time::sleep(reconnect_time).await;
    }

    // `last_error` is always set when the loop exits after failures; the
    // fallback only covers a limit of zero attempts.
    Err(last_error.unwrap_or_else(|| {
        std::io::Error::new(ErrorKind::Other, "Reached error connect limit")
    }))
}
async fn create_tunnel(ctx: Arc<ClientContext>, association: &str) -> std::io::Result<()> {
let io = create_server_tunnel(&ctx.http_client, &ctx.base_url, association)
.await
.map(|io| Framed::new(io, TunnelCodec::default()))
.map_err(|err| std::io::Error::new(ErrorKind::Other, err))?;
debug!("Created server tunnel");
let (tx, rx) = mpsc::unbounded_channel();
let pool = Socket::allocate_pool(tx).await?;
debug!("Allocated tunnel pool");
Tunnel {
io,
rx,
pool,
write_state: Default::default(),
}
.await;
Ok(())
}
/// Future that bridges the server tunnel connection and the local socket pool.
///
/// Frames read from `io` are forwarded to `pool[index]`. Messages the pooled
/// sockets push into `rx` are written back out to `io`.
struct Tunnel {
    /// Framed tunnel connection to the server.
    io: Framed<Upgraded, TunnelCodec>,
    /// Messages read by the pooled sockets, waiting to be written to `io`.
    rx: mpsc::UnboundedReceiver<TunnelMessage>,
    /// Senders into each pooled socket's task, addressed by message index.
    pool: [SocketHandle; SOCKET_POOL_SIZE],
    /// Progress of the current write to `io`.
    write_state: TunnelWriteState,
}
/// State machine for writing messages to the tunnel connection.
#[derive(Default)]
enum TunnelWriteState {
    /// Waiting for the next outgoing message on the channel.
    #[default]
    Recv,
    /// Waiting for the connection to accept the held message. The `Option`
    /// lets the message be moved out with `take()` once it can be sent.
    Write(Option<TunnelMessage>),
    /// Message handed to the connection; flushing it.
    Flush,
    /// The channel closed or the connection failed; must not be polled again.
    Stop,
}
/// Outcome of handling one frame read from the tunnel connection.
enum TunnelReadState {
    /// The frame was handled; keep reading.
    Continue,
    /// Reading (or a write triggered by the frame) ended or failed; the
    /// tunnel should finish.
    Stop,
}
impl Tunnel {
fn poll_write_state(&mut self, cx: &mut Context<'_>) -> Poll<TunnelWriteState> {
Poll::Ready(match &mut self.write_state {
TunnelWriteState::Recv => {
let result = ready!(Pin::new(&mut self.rx).poll_recv(cx));
if let Some(message) = result {
TunnelWriteState::Write(Some(message))
} else {
TunnelWriteState::Stop
}
}
TunnelWriteState::Write(message) => {
if ready!(Pin::new(&mut self.io).poll_ready(cx)).is_ok() {
let message = message
.take()
.expect("Unexpected write state without message");
Pin::new(&mut self.io)
.start_send(message)
.expect("Message encoder errored");
TunnelWriteState::Flush
} else {
TunnelWriteState::Stop
}
}
TunnelWriteState::Flush => {
if ready!(Pin::new(&mut self.io).poll_flush(cx)).is_ok() {
TunnelWriteState::Recv
} else {
TunnelWriteState::Stop
}
}
TunnelWriteState::Stop => panic!("Tunnel polled after already stopped"),
})
}
fn poll_read_state(&mut self, cx: &mut Context<'_>) -> Poll<TunnelReadState> {
let Some(Ok(message)) = ready!(Pin::new(&mut self.io).poll_next(cx)) else {
return Poll::Ready(TunnelReadState::Stop);
};
if message.index == 255 {
if let TunnelWriteState::Recv = self.write_state {
self.write_state = TunnelWriteState::Write(Some(TunnelMessage {
index: 255,
message: Bytes::new(),
}));
if let Poll::Ready(next_state) = self.poll_write_state(cx) {
self.write_state = next_state;
if let TunnelWriteState::Stop = self.write_state {
return Poll::Ready(TunnelReadState::Stop);
}
}
}
return Poll::Ready(TunnelReadState::Continue);
}
let handle = self.pool.get(message.index as usize);
if let Some(handle) = handle {
_ = handle.0.send(message);
}
Poll::Ready(TunnelReadState::Continue)
}
}
impl Future for Tunnel {
    type Output = ();

    /// Pumps the tunnel. First it writes outgoing traffic until the writer
    /// blocks, then it handles incoming frames until nothing more can be
    /// read. Completes as soon as either direction stops.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let tunnel = self.get_mut();

        // Outgoing: socket pool -> server.
        loop {
            match tunnel.poll_write_state(cx) {
                Poll::Pending => break,
                Poll::Ready(TunnelWriteState::Stop) => {
                    tunnel.write_state = TunnelWriteState::Stop;
                    return Poll::Ready(());
                }
                Poll::Ready(state) => tunnel.write_state = state,
            }
        }

        // Incoming: server -> socket pool.
        loop {
            match tunnel.poll_read_state(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(TunnelReadState::Stop) => return Poll::Ready(()),
                Poll::Ready(TunnelReadState::Continue) => {}
            }
        }
    }
}
/// Sender used by the tunnel to forward frames to one pooled socket's task.
#[derive(Clone)]
struct SocketHandle(mpsc::UnboundedSender<TunnelMessage>);
/// Size in bytes (64 KiB) of each pooled socket's receive buffer.
const READ_BUFFER_LENGTH: usize = 1 << 16;
/// Task bridging one local UDP socket and the tunnel.
///
/// Datagrams received on `socket` are tagged with `index` and sent to the
/// tunnel through `tun_tx`. Payloads arriving on `rx` are sent out of `socket`.
struct Socket {
    /// Pool slot of this socket; stamped on every message it forwards.
    index: u8,
    /// Local UDP socket, connected to `LOCAL_SEND_TARGET`.
    socket: UdpSocket,
    /// Messages from the tunnel destined for this socket.
    rx: mpsc::UnboundedReceiver<TunnelMessage>,
    /// Channel into the tunnel for datagrams read from `socket`.
    tun_tx: mpsc::UnboundedSender<TunnelMessage>,
    /// Receive buffer reused for every datagram. NOTE(review): stored inline,
    /// so each `Socket` value is over 64 KiB.
    read_buffer: [u8; READ_BUFFER_LENGTH],
    /// Progress of the current send on `socket`.
    write_state: SocketWriteState,
}
/// State machine for sending tunnel payloads out of a pooled socket.
#[derive(Default)]
enum SocketWriteState {
    /// Waiting for the next message from the tunnel.
    #[default]
    Recv,
    /// Sending the held payload on the socket.
    Write(Bytes),
    /// The channel closed or the socket failed; must not be polled again.
    Stop,
}
/// Outcome of handling one datagram received on a pooled socket.
enum SocketReadState {
    /// The datagram was handled; keep reading.
    Continue,
    /// The socket errored or the tunnel channel closed; the task should end.
    Stop,
}
impl Socket {
async fn allocate_pool(
tun_tx: mpsc::UnboundedSender<TunnelMessage>,
) -> std::io::Result<[SocketHandle; SOCKET_POOL_SIZE]> {
let sockets = try_join!(
Socket::start(0, TUNNEL_HOST_PORT, tun_tx.clone()),
Socket::start(1, RANDOM_PORT, tun_tx.clone()),
Socket::start(2, RANDOM_PORT, tun_tx.clone()),
Socket::start(3, RANDOM_PORT, tun_tx),
)?;
Ok(sockets.into())
}
async fn start(
index: u8,
port: u16,
tun_tx: mpsc::UnboundedSender<TunnelMessage>,
) -> std::io::Result<SocketHandle> {
let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, port)).await?;
socket.connect(LOCAL_SEND_TARGET).await?;
let (tx, rx) = mpsc::unbounded_channel();
spawn_server_task(Socket {
index,
socket,
rx,
tun_tx,
read_buffer: [0; READ_BUFFER_LENGTH],
write_state: Default::default(),
});
Ok(SocketHandle(tx))
}
fn poll_write_state(&mut self, cx: &mut Context<'_>) -> Poll<SocketWriteState> {
Poll::Ready(match &mut self.write_state {
SocketWriteState::Recv => {
let result = ready!(Pin::new(&mut self.rx).poll_recv(cx));
if let Some(message) = result {
SocketWriteState::Write(message.message)
} else {
SocketWriteState::Stop
}
}
SocketWriteState::Write(message) => {
let Ok(count) = ready!(self.socket.poll_send(cx, message)) else {
return Poll::Ready(SocketWriteState::Stop);
};
if count != message.len() {
let message = message.slice(count..);
SocketWriteState::Write(message)
} else {
SocketWriteState::Recv
}
}
SocketWriteState::Stop => panic!("Tunnel socket polled after already stopped"),
})
}
fn poll_read_state(&mut self, cx: &mut Context<'_>) -> Poll<SocketReadState> {
let mut read_buf = ReadBuf::new(&mut self.read_buffer);
if ready!(self.socket.poll_recv(cx, &mut read_buf)).is_err() {
return Poll::Ready(SocketReadState::Stop);
}
let bytes = read_buf.filled();
let message = Bytes::copy_from_slice(bytes);
let message = TunnelMessage {
index: self.index,
message,
};
Poll::Ready(if self.tun_tx.send(message).is_ok() {
SocketReadState::Continue
} else {
SocketReadState::Stop
})
}
}
impl Future for Socket {
    type Output = ();

    /// Sends queued tunnel payloads until the socket blocks, then forwards
    /// received datagrams until nothing more is readable. Completes as soon
    /// as either direction stops.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let task = self.get_mut();

        // Tunnel -> local socket.
        loop {
            match task.poll_write_state(cx) {
                Poll::Pending => break,
                Poll::Ready(SocketWriteState::Stop) => {
                    task.write_state = SocketWriteState::Stop;
                    return Poll::Ready(());
                }
                Poll::Ready(state) => task.write_state = state,
            }
        }

        // Local socket -> tunnel.
        loop {
            match task.poll_read_state(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(SocketReadState::Stop) => return Poll::Ready(()),
                Poll::Ready(SocketReadState::Continue) => {}
            }
        }
    }
}
mod codec {
    //! Length-prefixed framing for tunnel messages.
    //!
    //! Each frame is a 3-byte header followed by `length` payload bytes. The
    //! header holds an `index: u8` and a big-endian `length: u16`.

    use bytes::{Buf, BufMut, Bytes};
    use std::io;
    use tokio_util::codec::{Decoder, Encoder};

    /// Size in bytes of a frame header: one index byte plus a `u16` length.
    const HEADER_LENGTH: usize = 3;

    /// Header of a frame whose payload has not fully arrived yet.
    struct TunnelMessageHeader {
        index: u8,
        length: u16,
    }

    /// One tunnel frame: the pooled socket index and its payload.
    pub struct TunnelMessage {
        pub index: u8,
        pub message: Bytes,
    }

    /// Encoder/decoder for [`TunnelMessage`] frames.
    #[derive(Default)]
    pub struct TunnelCodec {
        /// Header already consumed from the buffer while waiting for the rest
        /// of its payload.
        partial: Option<TunnelMessageHeader>,
    }

    impl Decoder for TunnelCodec {
        type Item = TunnelMessage;
        type Error = std::io::Error;

        /// Decodes one frame, returning `Ok(None)` until it is complete.
        fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            let length = match self.partial.as_ref() {
                Some(header) => usize::from(header.length),
                None => {
                    // Payloads may be empty (or a single byte), so only the
                    // header itself is required before parsing it.
                    if src.len() < HEADER_LENGTH {
                        return Ok(None);
                    }
                    let index = src.get_u8();
                    let length = src.get_u16();
                    self.partial.insert(TunnelMessageHeader { index, length });
                    usize::from(length)
                }
            };

            if src.len() < length {
                // Make room for the rest of the payload before waiting for it.
                src.reserve(length - src.len());
                return Ok(None);
            }

            let header = self.partial.take().expect("Partial frame missing");
            let payload = src.split_to(length);
            Ok(Some(TunnelMessage {
                index: header.index,
                message: payload.freeze(),
            }))
        }
    }

    impl Encoder<TunnelMessage> for TunnelCodec {
        type Error = std::io::Error;

        /// Appends `item` as one frame to `dst`.
        ///
        /// # Errors
        ///
        /// Returns `InvalidInput` if the payload is longer than `u16::MAX`
        /// bytes, since the header cannot describe it. Nothing is written in
        /// that case.
        fn encode(
            &mut self,
            item: TunnelMessage,
            dst: &mut bytes::BytesMut,
        ) -> Result<(), Self::Error> {
            // Reject rather than truncate: a wrong length would desynchronise
            // every following frame on the stream.
            let length = u16::try_from(item.message.len()).map_err(|_| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "tunnel message payload exceeds u16::MAX bytes",
                )
            })?;
            dst.reserve(HEADER_LENGTH + item.message.len());
            dst.put_u8(item.index);
            dst.put_u16(length);
            dst.extend_from_slice(&item.message);
            Ok(())
        }
    }
}