use crate::socket::InMessage;
use crate::CloseFrame;
use crate::Error;
use crate::Message;
use crate::Request;
use crate::Session;
use crate::SessionExt;
use crate::Socket;
use async_trait::async_trait;
use std::net::SocketAddr;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
/// A socket handed to [`Server::accept`], queued for the server actor, which passes it to
/// [`ServerExt::on_connect`].
struct NewConnection {
/// The socket itself; moved into `on_connect`.
socket: Socket,
/// Address the connection came from (used in the accept/reject log lines).
address: SocketAddr,
/// The `Request` given to `Server::accept` alongside the socket, forwarded unchanged to
/// `on_connect`.
request: Request,
}
/// Notification that a session has closed, sent through [`Server::disconnected`] and
/// consumed by the server actor.
struct Disconnected<E: ServerExt> {
/// Identifier of the session that closed.
id: <E::Session as SessionExt>::ID,
/// Outcome of the session's `await_close()`; logged by the actor and then passed to
/// `ServerExt::on_disconnect`.
result: Result<Option<CloseFrame>, Error>,
}
/// State owned by the background task spawned in [`Server::create`].
///
/// Holds the receiving ends of the channels whose senders live in every [`Server`] handle.
struct ServerActor<E: ServerExt> {
/// New sockets submitted via [`Server::accept`].
connection_receiver: mpsc::UnboundedReceiver<NewConnection>,
/// Session-closed notifications submitted via [`Server::disconnected`].
disconnection_receiver: mpsc::UnboundedReceiver<Disconnected<E>>,
/// User calls submitted via [`Server::call`] / [`Server::call_with`].
server_call_receiver: mpsc::UnboundedReceiver<E::Call>,
/// A handle to this same server, cloned into each per-session task so it can report the
/// disconnection. Because it holds a sender for every channel above, those receivers cannot
/// all close while the actor is alive.
server: Server<E>,
/// User-supplied hooks invoked by the actor loop.
extension: E,
}
impl<E: ServerExt> ServerActor<E>
where
E: Send + 'static,
<E::Session as SessionExt>::ID: Send,
{
async fn run(mut self) {
tracing::info!("starting websocket server");
loop {
if let Err(err) = async {
tokio::select! {
Some(NewConnection{socket, address, request}) = self.connection_receiver.recv() => {
let socket_sink = socket.sink.clone();
match self.extension.on_connect(socket, request, address).await {
Ok(session) => {
tracing::info!("connection from {address} accepted");
let session_id = session.id.clone();
tokio::spawn({
let server = self.server.clone();
async move {
let result = session.await_close().await;
server.disconnected(session_id, result);
}
});
}
Err(err) => {
tracing::info!(?err, "connection from {address} rejected");
if let Err(err) = socket_sink.send(InMessage::new(Message::Close(err))).await {
tracing::warn!(?err, "failed forwarding close frame to socket after connection rejected");
}
}
}
}
Some(Disconnected{id, result}) = self.disconnection_receiver.recv() => {
match &result {
Ok(Some(CloseFrame { code, reason })) => {
tracing::info!(%id, ?code, %reason, "connection closed")
}
Ok(None) => tracing::info!(%id, "connection closed"),
Err(err) => tracing::warn!(%id, "connection closed due to: {err:?}"),
};
self.extension.on_disconnect(id.clone(), result).await?;
}
Some(call) = self.server_call_receiver.recv() => {
self.extension.on_call(call).await?
}
else => return Err("server actor branches broken".into()),
}
Ok::<_, Error>(())
}
.await {
tracing::warn!("error when processing: {err:?}");
}
}
}
}
/// Server-side hooks invoked by the server actor.
///
/// The implementation is built by the closure passed to [`Server::create`] and is then owned by
/// the actor task. The actor awaits each hook to completion before handling the next event, so
/// calls never overlap.
#[async_trait]
pub trait ServerExt: Send {
/// Session implementation. Its `ID` and `Call` types parametrize the [`Session`] returned by
/// `on_connect`.
type Session: SessionExt;
/// Message type delivered to `on_call`, sent through [`Server::call`] or [`Server::call_with`].
type Call: Send;
/// Decides whether to accept a new connection submitted via [`Server::accept`].
///
/// - Return `Ok(session)` to accept. The actor then waits for the session to close and later
///   reports it through `on_disconnect`.
/// - Return `Err(frame)` to reject. The actor sends `Message::Close(frame)` to the socket.
async fn on_connect(
&mut self,
socket: Socket,
request: Request,
address: SocketAddr,
) -> Result<
Session<<Self::Session as SessionExt>::ID, <Self::Session as SessionExt>::Call>,
Option<CloseFrame>,
>;
/// Called after an accepted session has closed.
///
/// Receives the session `id` and the result of the session's `await_close()`. If this returns
/// `Err`, the actor logs it and keeps running.
async fn on_disconnect(
&mut self,
id: <Self::Session as SessionExt>::ID,
reason: Result<Option<CloseFrame>, Error>,
) -> Result<(), Error>;
/// Handles a user call sent through [`Server::call`] / [`Server::call_with`].
///
/// If this returns `Err`, the actor logs it and keeps running.
async fn on_call(&mut self, call: Self::Call) -> Result<(), Error>;
}
/// Cloneable handle to a running websocket server.
///
/// Every clone talks to the same actor task (spawned by [`Server::create`]) through unbounded
/// channels. The channels carry new sockets, session-closed notifications, and user calls.
pub struct Server<E: ServerExt> {
    /// Delivers new sockets to the actor (see [`Server::accept`]).
    connection_sender: mpsc::UnboundedSender<NewConnection>,
    /// Delivers session-closed notifications to the actor (see [`Server::disconnected`]).
    disconnection_sender: mpsc::UnboundedSender<Disconnected<E>>,
    /// Delivers user calls to the actor (see [`Server::call`]).
    server_call_sender: mpsc::UnboundedSender<E::Call>,
}

// Implemented by hand rather than derived: `#[derive(Debug)]` would bound the impl on
// `E: Debug`, although `E` is never stored here and the senders are `Debug` for any message
// type. The output matches what the derive produced.
impl<E: ServerExt> std::fmt::Debug for Server<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Server")
            .field("connection_sender", &self.connection_sender)
            .field("disconnection_sender", &self.disconnection_sender)
            .field("server_call_sender", &self.server_call_sender)
            .finish()
    }
}
impl<E: ServerExt> From<Server<E>> for mpsc::UnboundedSender<E::Call> {
fn from(server: Server<E>) -> Self {
server.server_call_sender
}
}
impl<E: ServerExt + 'static> Server<E> {
    /// Builds a server handle and spawns the actor task that serves it.
    ///
    /// `create` receives a clone of the handle, which the extension it returns may retain.
    /// The extension is then moved into the actor.
    ///
    /// Returns the handle together with the [`JoinHandle`] of the spawned actor task.
    pub fn create(create: impl FnOnce(Self) -> E) -> (Self, JoinHandle<()>) {
        let (server_call_sender, server_call_receiver) = mpsc::unbounded_channel();
        let (disconnection_sender, disconnection_receiver) = mpsc::unbounded_channel();
        let (connection_sender, connection_receiver) = mpsc::unbounded_channel();

        let server = Self {
            connection_sender,
            disconnection_sender,
            server_call_sender,
        };
        let extension = create(server.clone());

        let actor_task = tokio::spawn(
            ServerActor {
                connection_receiver,
                disconnection_receiver,
                server_call_receiver,
                server: server.clone(),
                extension,
            }
            .run(),
        );

        (server, actor_task)
    }
}
impl<E: ServerExt> Server<E> {
    /// Hands a freshly accepted socket to the server actor, which passes it to
    /// [`ServerExt::on_connect`].
    ///
    /// If the actor is no longer running, the connection is dropped and an error is logged.
    pub fn accept(&self, socket: Socket, request: Request, address: SocketAddr) {
        let connection = NewConnection {
            socket,
            request,
            address,
        };
        if self.connection_sender.send(connection).is_err() {
            tracing::error!("accepted a connection but the server actor is dead");
        }
    }

    /// Reports to the server actor that session `id` has closed with `result`.
    ///
    /// This is best effort. If the actor is no longer running, the report is dropped and a
    /// warning is logged (previously the failure was silently discarded).
    pub(crate) fn disconnected(
        &self,
        id: <E::Session as SessionExt>::ID,
        result: Result<Option<CloseFrame>, Error>,
    ) {
        if self
            .disconnection_sender
            .send(Disconnected { id, result })
            .is_err()
        {
            tracing::warn!("session disconnected but the server actor is dead");
        }
    }

    /// Sends `call` to the server actor, which forwards it to [`ServerExt::on_call`].
    ///
    /// # Errors
    ///
    /// Returns the call wrapped in a [`mpsc::error::SendError`] if the actor is no longer
    /// running.
    pub fn call(&self, call: E::Call) -> Result<(), mpsc::error::SendError<E::Call>> {
        self.server_call_sender.send(call)
    }

    /// Sends a call that carries a reply channel, then waits for the reply.
    ///
    /// `f` builds the call from the [`oneshot::Sender`] that the handler should answer on.
    ///
    /// Returns `None` in either of these cases:
    /// - the actor is no longer running;
    /// - the reply sender is dropped without a value being sent.
    pub async fn call_with<R>(
        &self,
        f: impl FnOnce(oneshot::Sender<R>) -> E::Call,
    ) -> Option<R> {
        let (sender, receiver) = oneshot::channel();
        self.server_call_sender.send(f(sender)).ok()?;
        receiver.await.ok()
    }
}
/// Clones the handle. The clone sends to the same server actor.
///
/// A derived impl would require `E: Clone`; this one clones only the senders.
impl<E: ServerExt> Clone for Server<E> {
    fn clone(&self) -> Self {
        let Self {
            connection_sender,
            disconnection_sender,
            server_call_sender,
        } = self;
        Self {
            connection_sender: connection_sender.clone(),
            disconnection_sender: disconnection_sender.clone(),
            server_call_sender: server_call_sender.clone(),
        }
    }
}