use crate::socket::CloseCode;
use crate::socket::InMessage;
use crate::CloseFrame;
use crate::Error;
use crate::Message;
use crate::Request;
use crate::Session;
use crate::SessionExt;
use crate::Socket;
use async_trait::async_trait;
use std::net::SocketAddr;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::watch;
use tokio::task::JoinHandle;
/// A connection queued for the server actor by `Server::accept`.
struct NewConnection {
    // Socket of the connection; passed to `ServerExt::on_connect`.
    socket: Socket,
    // Peer address; passed to `on_connect` and included in log messages.
    address: SocketAddr,
    // Request associated with the connection; passed to `on_connect`.
    request: Request,
}
/// Notification sent to the server actor when an accepted session has ended.
struct Disconnected<E: ServerExt> {
    // ID of the session that ended.
    id: <E::Session as SessionExt>::ID,
    // Outcome reported by the session's `await_close()`; forwarded to `on_disconnect`.
    result: Result<Option<CloseFrame>, Error>,
}
/// State of the task that owns all server-side bookkeeping.
///
/// It is spawned by `Server::create` and driven by `ServerActor::run`. Every `Server`
/// handle communicates with it only through the channels below.
struct ServerActor<E: ServerExt> {
    // Connections handed over by `Server::accept`.
    connection_receiver: mpsc::UnboundedReceiver<NewConnection>,
    // Session-ended notifications sent by `Server::disconnected`.
    disconnection_receiver: mpsc::UnboundedReceiver<Disconnected<E>>,
    // Custom calls from `Server::call` / `Server::call_with`, forwarded to `on_call`.
    server_call_receiver: mpsc::UnboundedReceiver<E::Call>,
    // Graceful-shutdown requests; each carries the sender used to acknowledge completion.
    shutdown_receiver: mpsc::UnboundedReceiver<oneshot::Sender<()>>,
    // Flipped to `true` when shutdown starts; each session's watcher task subscribes to it.
    shutdown_signal: watch::Sender<bool>,
    // Handle cloned into session tasks so they can report disconnections. Because the
    // actor holds these senders itself, its own receivers never observe a closed channel.
    server: Server<E>,
    // User-supplied server behaviour.
    extension: E,
}
impl<E: ServerExt> ServerActor<E>
where
E: Send + 'static,
<E::Session as SessionExt>::ID: Send,
{
async fn run(mut self) {
tracing::info!("starting websocket server");
let mut active_sessions: usize = 0;
let mut shutting_down = false;
let mut shutdown_acks: Vec<oneshot::Sender<()>> = Vec::new();
loop {
if shutting_down && active_sessions == 0 {
tracing::info!("server shutdown complete");
for ack in shutdown_acks.drain(..) {
let _ = ack.send(());
}
break;
}
if let Err(err) = async {
tokio::select! {
Some(ack) = self.shutdown_receiver.recv() => {
shutdown_acks.push(ack);
if !shutting_down {
tracing::info!(active_sessions, "graceful shutdown initiated");
shutting_down = true;
let _ = self.shutdown_signal.send(true);
if active_sessions == 0 {
}
}
}
Some(NewConnection{socket, address, request}) = self.connection_receiver.recv() => {
if shutting_down {
tracing::info!("rejecting connection from {address} during shutdown");
let _ = socket
.sink
.send(InMessage::new(Message::Close(Some(CloseFrame {
code: CloseCode::Away,
reason: "server shutting down".into(),
}))))
.await;
return Ok(());
}
let socket_sink = socket.sink.clone();
match self.extension.on_connect(socket, request, address).await {
Ok(session) => {
tracing::info!("connection from {address} accepted");
let session_id = session.id.clone();
active_sessions += 1;
let shutdown_rx = self.shutdown_signal.subscribe();
tokio::spawn({
let server = self.server.clone();
let close_session = session.clone();
async move {
let close_task = tokio::spawn(watch_for_shutdown(shutdown_rx, close_session));
let result = session.await_close().await;
close_task.abort();
server.disconnected(session_id, result);
}
});
}
Err(err) => {
tracing::info!(?err, "connection from {address} rejected");
if let Err(err) = socket_sink.send(InMessage::new(Message::Close(err))).await {
tracing::warn!(?err, "failed forwarding close frame to socket after connection rejected");
}
}
}
}
Some(Disconnected{id, result}) = self.disconnection_receiver.recv() => {
match &result {
Ok(Some(CloseFrame { code, reason })) => {
tracing::info!(%id, ?code, %reason, "connection closed")
}
Ok(None) => tracing::info!(%id, "connection closed"),
Err(err) => tracing::warn!(%id, "connection closed due to: {err:?}"),
};
active_sessions = active_sessions.saturating_sub(1);
self.extension.on_disconnect(id.clone(), result).await?;
}
Some(call) = self.server_call_receiver.recv(), if !shutting_down => {
self.extension.on_call(call).await?
}
else => return Err("server actor branches broken".into()),
}
Ok::<_, Error>(())
}
.await {
tracing::warn!("error when processing: {err:?}");
}
}
}
}
/// Waits for the server's shutdown flag to become `true`, then asks `session` to close
/// with `CloseCode::Away`.
///
/// Returns without touching the session if `rx.changed()` fails, which tokio reports once
/// the sending side (the actor's `shutdown_signal`) is gone. The spawning task aborts this
/// watcher after the session closes on its own.
async fn watch_for_shutdown<I, C>(mut rx: watch::Receiver<bool>, session: Session<I, C>)
where
    I: std::fmt::Display + Clone + Send + 'static,
    C: Send + 'static,
{
    // The `borrow()` guard is a temporary of the `while` condition, so it is released
    // before `changed()` is awaited.
    while !*rx.borrow() {
        if rx.changed().await.is_err() {
            return;
        }
    }
    let away = CloseFrame {
        code: CloseCode::Away,
        reason: "server shutting down".into(),
    };
    // NOTE(review): assumes `Session::close` is synchronous (returns a `Result`). If it ever
    // returned a future, `let _ =` would drop it unpolled. Confirm against `Session`.
    let _ = session.close(Some(away));
}
/// User-supplied behaviour of a websocket server.
///
/// Every method is awaited directly on the server actor's task, one event at a time. A slow
/// callback therefore delays handling of all other connections, disconnections and calls.
/// An `Err` returned from `on_disconnect` or `on_call` is logged by the actor and does not
/// stop it.
#[async_trait]
pub trait ServerExt: Send {
    /// Session type created for each accepted connection. Its `ID` and `Call` types
    /// parameterise the `Session` handle returned by `on_connect`.
    type Session: SessionExt;
    /// Message type accepted by `Server::call` / `Server::call_with` and delivered to
    /// `on_call`.
    type Call: Send;
    /// Decides whether to accept a new connection.
    ///
    /// Return the session handle to accept; the actor then tracks it until its
    /// `await_close()` resolves. Return `Err(frame)` to reject; `frame` is sent to the
    /// socket as a close message. Not called for connections that arrive during a graceful
    /// shutdown; those are closed with `CloseCode::Away` instead.
    async fn on_connect(
        &mut self,
        socket: Socket,
        request: Request,
        address: SocketAddr,
    ) -> Result<
        Session<<Self::Session as SessionExt>::ID, <Self::Session as SessionExt>::Call>,
        Option<CloseFrame>,
    >;
    /// Called once an accepted session has ended, with the result of that session's
    /// `await_close()`.
    async fn on_disconnect(
        &mut self,
        id: <Self::Session as SessionExt>::ID,
        reason: Result<Option<CloseFrame>, Error>,
    ) -> Result<(), Error>;
    /// Handles a custom call sent through `Server::call` or `Server::call_with`.
    ///
    /// Calls are not dispatched once a graceful shutdown has started.
    async fn on_call(&mut self, call: Self::Call) -> Result<(), Error>;
}
/// Handle to a websocket server.
///
/// All clones send to the same actor task spawned by `Server::create`; every operation on
/// the handle is a message sent to that actor.
#[derive(Debug)]
pub struct Server<E: ServerExt> {
    // Connections handed over by `accept`.
    connection_sender: mpsc::UnboundedSender<NewConnection>,
    // Session-ended notifications sent by `disconnected`.
    disconnection_sender: mpsc::UnboundedSender<Disconnected<E>>,
    // Custom calls destined for `ServerExt::on_call`.
    server_call_sender: mpsc::UnboundedSender<E::Call>,
    // Graceful-shutdown requests, each carrying the sender used to acknowledge completion.
    shutdown_sender: mpsc::UnboundedSender<oneshot::Sender<()>>,
}
impl<E: ServerExt> From<Server<E>> for mpsc::UnboundedSender<E::Call> {
fn from(server: Server<E>) -> Self {
server.server_call_sender
}
}
impl<E: ServerExt + 'static> Server<E> {
    /// Builds a server and spawns its actor with `tokio::spawn`, so this must be called
    /// from within a Tokio runtime.
    ///
    /// `create` receives a handle to the new server so the extension can keep one for
    /// itself. Returns the caller's handle together with the actor's `JoinHandle`. The actor
    /// loop only exits after a graceful shutdown has completed.
    pub fn create(create: impl FnOnce(Self) -> E) -> (Self, JoinHandle<()>) {
        let (connection_tx, connection_rx) = mpsc::unbounded_channel();
        let (disconnection_tx, disconnection_rx) = mpsc::unbounded_channel();
        let (call_tx, call_rx) = mpsc::unbounded_channel();
        let (shutdown_tx, shutdown_rx) = mpsc::unbounded_channel();
        // The initial receiver is not needed; session tasks subscribe on demand.
        let (shutdown_signal, _) = watch::channel(false);

        let server = Self {
            connection_sender: connection_tx,
            server_call_sender: call_tx,
            disconnection_sender: disconnection_tx,
            shutdown_sender: shutdown_tx,
        };
        let extension = create(server.clone());
        let actor_task = tokio::spawn(
            ServerActor {
                connection_receiver: connection_rx,
                disconnection_receiver: disconnection_rx,
                server_call_receiver: call_rx,
                shutdown_receiver: shutdown_rx,
                shutdown_signal,
                extension,
                server: server.clone(),
            }
            .run(),
        );
        (server, actor_task)
    }
}
impl<E: ServerExt> Server<E> {
    /// Hands a websocket connection to the server actor.
    ///
    /// The actor offers it to [`ServerExt::on_connect`], or closes it with `CloseCode::Away`
    /// if a graceful shutdown is in progress. If the actor is no longer running, the
    /// connection is dropped and an error is logged.
    pub fn accept(&self, socket: Socket, request: Request, address: SocketAddr) {
        let connection = NewConnection {
            socket,
            request,
            address,
        };
        if self.connection_sender.send(connection).is_err() {
            tracing::error!("accepted a connection but the server actor is dead");
        }
    }

    /// Tells the actor that session `id` has ended with `result`.
    ///
    /// The actor's per-session task calls this once the session's `await_close()` resolves.
    /// If the actor has already stopped, there is nobody left to notify. The event is
    /// dropped, but a warning is logged rather than failing silently.
    pub(crate) fn disconnected(
        &self,
        id: <E::Session as SessionExt>::ID,
        result: Result<Option<CloseFrame>, Error>,
    ) {
        if self
            .disconnection_sender
            .send(Disconnected { id, result })
            .is_err()
        {
            tracing::warn!("session disconnected but the server actor is dead");
        }
    }

    /// Sends `call` to [`ServerExt::on_call`] without waiting for it to be handled.
    ///
    /// # Errors
    ///
    /// Returns the call inside the `SendError` if the server actor is no longer running.
    /// Calls sent while a graceful shutdown is in progress are accepted here but never
    /// dispatched.
    pub fn call(&self, call: E::Call) -> Result<(), mpsc::error::SendError<E::Call>> {
        self.server_call_sender.send(call)
    }

    /// Sends a call carrying a reply channel and waits for the extension's answer.
    ///
    /// `f` receives the [`oneshot::Sender`] the extension should reply on and returns the
    /// call to dispatch. Returns `None` in two cases: the actor is not running, or the reply
    /// sender is dropped without an answer. The second case also covers calls queued during
    /// a graceful shutdown, which are dropped when the actor exits.
    pub async fn call_with<R: std::fmt::Debug>(
        &self,
        f: impl FnOnce(oneshot::Sender<R>) -> E::Call,
    ) -> Option<R> {
        let (sender, receiver) = oneshot::channel();
        self.server_call_sender.send(f(sender)).ok()?;
        receiver.await.ok()
    }

    /// Gracefully shuts the server down and waits for the shutdown to finish.
    ///
    /// The actor stops dispatching calls, closes new connections with `CloseCode::Away`, and
    /// asks every live session to close. It acknowledges this request once all sessions have
    /// disconnected, and the actor task then exits.
    ///
    /// Do not await this from inside a [`ServerExt`] callback. Callbacks run on the actor
    /// task, and that task is the one that sends the acknowledgement, so it would deadlock.
    ///
    /// # Errors
    ///
    /// Returns [`GracefulShutdownError::ServerStopped`] if the actor is not running or stops
    /// without acknowledging the request.
    pub async fn graceful_shutdown(&self) -> Result<(), GracefulShutdownError> {
        let (ack_tx, ack_rx) = oneshot::channel();
        self.shutdown_sender
            .send(ack_tx)
            .map_err(|_| GracefulShutdownError::ServerStopped)?;
        ack_rx.await.map_err(|_| GracefulShutdownError::ServerStopped)
    }
}
/// Error returned by `Server::graceful_shutdown`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GracefulShutdownError {
    /// The server actor is not running. Either the shutdown request could not be delivered,
    /// or the actor stopped without acknowledging it.
    ServerStopped,
}

impl std::fmt::Display for GracefulShutdownError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ServerStopped => f.write_str("server actor has already stopped"),
        }
    }
}

impl std::error::Error for GracefulShutdownError {}
/// Implemented by hand because `#[derive(Clone)]` would add an `E: Clone` bound, even though
/// only the channel senders need cloning.
impl<E: ServerExt> Clone for Server<E> {
    fn clone(&self) -> Self {
        let Self {
            connection_sender,
            disconnection_sender,
            server_call_sender,
            shutdown_sender,
        } = self;
        Self {
            connection_sender: connection_sender.clone(),
            disconnection_sender: disconnection_sender.clone(),
            server_call_sender: server_call_sender.clone(),
            shutdown_sender: shutdown_sender.clone(),
        }
    }
}