use {
super::{ServerConfig, ServerError, ToConnected, ToOpen},
crate::{
server::{HandshakeHandler, ToConnecting},
session::SessionError,
},
aeronet_io::{connection::DisconnectReason, server::CloseReason},
bevy_ecs::prelude::*,
core::{
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
},
futures::{
FutureExt, SinkExt, StreamExt,
channel::{mpsc, oneshot},
never::Never,
},
tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::{TcpListener, TcpStream},
},
tokio_rustls::TlsAcceptor,
tokio_tungstenite::tungstenite::{
handshake::server::{Request, Response},
protocol::WebSocketConfig,
},
tracing::{Instrument, debug, debug_span},
};
/// Runs the server backend until it is closed.
///
/// Binds a TCP listener on `config.bind_address` and sends the frontend its
/// [`ToOpen`] state through `tx_next`. It then accepts TCP connections in a
/// loop, spawning one task per connection that runs [`accept_session`].
/// Failures on those per-connection tasks are only logged at debug level and
/// never stop the server.
///
/// # Errors
///
/// Returns a [`CloseReason`] if:
/// - binding the listener fails;
/// - the listener's local address cannot be queried;
/// - the frontend drops the receiver for [`ToOpen`];
/// - `rx_dropped` yields anything while running (see below);
/// - accepting a TCP connection fails.
pub async fn start(
    config: ServerConfig,
    tx_next: oneshot::Sender<ToOpen>,
) -> Result<Never, CloseReason> {
    let tls_acceptor = config.tls.map(TlsAcceptor::from);
    let listener = TcpListener::bind(config.bind_address)
        .await
        .map_err(ServerError::BindSocket)?;
    // Log the address the listener actually bound rather than the configured
    // one: if the configured port was 0, only `local_addr` has the real port.
    let local_addr = listener.local_addr().map_err(SessionError::GetLocalAddr)?;
    debug!("Listening on {local_addr}");

    let (tx_connecting, rx_connecting) = mpsc::channel::<ToConnecting>(1);
    // The loop below stops as soon as `rx_dropped` yields anything: a message
    // or end-of-stream once every `tx_dropped` is dropped. Presumably the
    // frontend only ever drops it, never sends on it — confirm.
    let (tx_dropped, mut rx_dropped) = mpsc::channel::<()>(0);
    let next = ToOpen {
        local_addr,
        rx_connecting,
        tx_dropped,
    };
    tx_next
        .send(next)
        .map_err(|_| SessionError::FrontendClosed)?;

    debug!("Starting server loop");
    loop {
        let result = futures::select! {
            x = listener.accept().fuse() => x,
            _ = rx_dropped.next() => {
                return Err(CloseReason::ByError(SessionError::FrontendClosed.into()));
            }
        };
        // NOTE(review): any accept error closes the whole server, including
        // kinds that are often transient (e.g. running out of file
        // descriptors). Consider logging and continuing instead — confirm the
        // intended policy.
        let (stream, peer_addr) = result.map_err(ServerError::AcceptConnection)?;
        // Copy the per-connection settings out of `config` up front so the
        // spawned task owns exactly what it needs.
        let socket_config = config.socket;
        tokio::spawn({
            let tx_connecting = tx_connecting.clone();
            let tls_acceptor = tls_acceptor.clone();
            let handshake_handler = config.handshake_handler.clone();
            async move {
                if let Err(err) = accept_session(
                    stream,
                    peer_addr,
                    socket_config,
                    tls_acceptor,
                    tx_connecting,
                    handshake_handler,
                )
                .await
                {
                    debug!("Failed to accept session: {err:?}");
                }
            }
        });
    }
}
/// Registers a freshly accepted TCP connection with the frontend.
///
/// Waits for the frontend to assign the connection a session [`Entity`], then
/// runs the session inside a `session` tracing span. When the session ends,
/// the disconnect reason is forwarded to the frontend.
///
/// # Errors
///
/// Returns [`SessionError::FrontendClosed`] (converted into a
/// [`DisconnectReason`]) in two cases: the `ToConnecting` message cannot be
/// sent, or the frontend drops its end of the session-entity channel without
/// answering. Errors from the session itself are not returned. They are sent
/// to the frontend through the disconnect-reason channel, and this function
/// then returns `Ok(())`.
async fn accept_session(
    stream: TcpStream,
    peer_addr: SocketAddr,
    socket_config: WebSocketConfig,
    tls_acceptor: Option<TlsAcceptor>,
    mut tx_connecting: mpsc::Sender<ToConnecting>,
    handshake_handler: Option<HandshakeHandler>,
) -> Result<(), DisconnectReason> {
    // Per-session one-shot channels shared with the frontend:
    // - the frontend answers with the session entity;
    // - `handle_session` later sends the connected state;
    // - this task finally sends the disconnect reason.
    let (tx_session_entity, rx_session_entity) = oneshot::channel::<Entity>();
    let (tx_next, rx_next) = oneshot::channel::<ToConnected>();
    let (tx_dc_reason, rx_dc_reason) = oneshot::channel::<DisconnectReason>();

    let request = ToConnecting {
        peer_addr,
        tx_session_entity,
        rx_dc_reason,
        rx_next,
    };
    if tx_connecting.send(request).await.is_err() {
        return Err(SessionError::FrontendClosed.into());
    }

    let Ok(session) = rx_session_entity.await else {
        return Err(SessionError::FrontendClosed.into());
    };

    let session_span = debug_span!("session", %session);
    let outcome = handle_session(
        stream,
        peer_addr,
        socket_config,
        tls_acceptor,
        tx_next,
        handshake_handler,
    )
    .instrument(session_span)
    .await;
    // `handle_session` can only ever return `Err`, so this pattern is irrefutable.
    let Err(dc_reason) = outcome;
    // Best effort: if the frontend is already gone there is nobody to notify.
    let _ = tx_dc_reason.send(dc_reason);
    Ok(())
}
/// Runs one session from handshake to disconnect.
///
/// Performs the TLS handshake (when `tls_acceptor` is set) and the WebSocket
/// handshake on an accepted TCP connection. It then hands the frontend half of
/// the split session to the frontend through `tx_next` and drives the backend
/// half until the session ends.
///
/// # Errors
///
/// This never returns `Ok`; the `Err` value is the reason the session ended:
/// - a TLS handshake failure;
/// - a WebSocket handshake failure, including one rejected through
///   `handshake_handler`;
/// - the frontend dropping the receiver for the connected state;
/// - whatever the backend session loop returns when it stops.
async fn handle_session(
    stream: TcpStream,
    peer_addr: SocketAddr,
    socket_config: WebSocketConfig,
    tls_acceptor: Option<TlsAcceptor>,
    tx_next: oneshot::Sender<ToConnected>,
    handshake_handler: Option<HandshakeHandler>,
) -> Result<Never, DisconnectReason> {
    debug!("Performing session handshake");
    // NOTE(review): no timeout is applied to either handshake here, so a peer
    // that stalls mid-handshake keeps this task waiting — confirm acceptable.
    let stream = match tls_acceptor {
        None => MaybeTlsStream::Plain(stream),
        Some(acceptor) => {
            let tls_stream = acceptor
                .accept(stream)
                .await
                .map_err(ServerError::TlsHandshake)?;
            MaybeTlsStream::Rustls(tls_stream)
        }
    };

    let ws_stream = tokio_tungstenite::accept_hdr_async_with_config(
        stream,
        #[expect(
            clippy::result_large_err,
            reason = "this `Result` is what `tokio_tungstenite` asks for"
        )]
        |req: &Request, resp: Response| {
            // Without a custom handler, the default handshake response is accepted as-is.
            if let Some(handler) = &handshake_handler {
                handler.handle(req, resp)
            } else {
                Ok(resp)
            }
        },
        Some(socket_config),
    )
    .await
    .map_err(ServerError::AcceptClient)?;

    let (frontend, backend) = crate::session::backend::native::split(ws_stream);
    debug!("Connected");
    tx_next
        .send(ToConnected {
            peer_addr,
            frontend,
        })
        .map_err(|_| SessionError::FrontendClosed)?;
    debug!("Starting session loop");
    backend.start().await
}
/// Stream for an accepted connection: either the raw stream, or the same
/// stream after a server-side TLS handshake via `tokio_rustls`.
///
/// Lets the WebSocket handshake and session run over one concrete type
/// regardless of whether TLS is enabled.
#[derive(Debug)]
#[expect(clippy::large_enum_variant, reason = "most users will use `Rustls`")]
enum MaybeTlsStream<S> {
    /// Unencrypted stream, used when no TLS acceptor is configured.
    Plain(S),
    /// Stream wrapped in a server-side `rustls` TLS session.
    Rustls(tokio_rustls::server::TlsStream<S>),
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<S> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.get_mut() {
Self::Plain(s) => Pin::new(s).poll_read(cx, buf),
Self::Rustls(s) => Pin::new(s).poll_read(cx, buf),
}
}
}
/// Writes are delegated unchanged to whichever stream is wrapped.
///
/// The vectored-write methods are forwarded too. Otherwise the trait defaults
/// would apply: `is_write_vectored` would report `false`, and
/// `poll_write_vectored` would write only the first non-empty buffer per call,
/// even when the inner stream supports real vectored writes.
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<S> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        match self.get_mut() {
            Self::Plain(s) => Pin::new(s).poll_write(cx, buf),
            Self::Rustls(s) => Pin::new(s).poll_write(cx, buf),
        }
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        match self.get_mut() {
            Self::Plain(s) => Pin::new(s).poll_write_vectored(cx, bufs),
            Self::Rustls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
        }
    }

    fn is_write_vectored(&self) -> bool {
        match self {
            Self::Plain(s) => s.is_write_vectored(),
            Self::Rustls(s) => s.is_write_vectored(),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        match self.get_mut() {
            Self::Plain(s) => Pin::new(s).poll_flush(cx),
            Self::Rustls(s) => Pin::new(s).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        match self.get_mut() {
            Self::Plain(s) => Pin::new(s).poll_shutdown(cx),
            Self::Rustls(s) => Pin::new(s).poll_shutdown(cx),
        }
    }
}