use std::{fmt::Debug, future::Future, net::SocketAddr, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
use crate::{
handle::ConnectionHandle,
room::RoomEvents,
types::{BinaryMessageEvent, CloseEvent, TextMessageEvent},
wynd::BoxFuture,
};
/// Signature of the close callback: receives the [`CloseEvent`] built from the peer's frame.
type CloseFn = dyn Fn(CloseEvent) -> BoxFuture<()> + Send + Sync;

/// Signature of the text-message callback: receives the event plus the connection's handle.
type TextMessageFn<T> =
    dyn Fn(TextMessageEvent, Arc<ConnectionHandle<T>>) -> BoxFuture<()> + Send + Sync;

/// Signature of the binary-message callback: receives the event plus the connection's handle.
type BinaryMessageFn<T> =
    dyn Fn(BinaryMessageEvent, Arc<ConnectionHandle<T>>) -> BoxFuture<()> + Send + Sync;

/// Signature of the open callback: receives the connection's handle.
type OpenFn<T> = dyn Fn(Arc<ConnectionHandle<T>>) -> BoxFuture<()> + Send + Sync;

/// Shared slot holding the close callback; `None` until one is registered.
type CloseHandler = Arc<tokio::sync::Mutex<Option<Box<CloseFn>>>>;

/// Shared slot holding the text-message callback; `None` until one is registered.
type TextMessageHandler<T> = Arc<tokio::sync::Mutex<Option<Box<TextMessageFn<T>>>>>;

/// Shared slot holding the binary-message callback; `None` until one is registered.
type BinaryMessageHandler<T> = Arc<tokio::sync::Mutex<Option<Box<BinaryMessageFn<T>>>>>;

/// Shared slot holding the open callback; `None` until one is registered.
type OpenHandler<T> = Arc<tokio::sync::Mutex<Option<Box<OpenFn<T>>>>>;
/// One WebSocket connection: its split stream, lifecycle state and registered callbacks.
///
/// Callbacks live in shared, lockable slots so they can be replaced while the message loop
/// (started by `on_open`) is running.
pub struct Connection<T: AsyncRead + AsyncWrite + Unpin + Debug + Send + 'static> {
    /// Identifier supplied when the connection was created.
    id: u64,
    /// Read half of the split WebSocket stream, consumed by the message loop.
    reader: Arc<tokio::sync::Mutex<futures::stream::SplitStream<WebSocketStream<T>>>>,
    /// Write half of the split WebSocket stream, shared with connection handles.
    pub(crate) writer:
        Arc<tokio::sync::Mutex<futures::stream::SplitSink<WebSocketStream<T>, Message>>>,
    /// Socket address supplied at construction (presumably the peer's — confirm at caller).
    addr: SocketAddr,
    /// Callback run once when the connection is opened.
    open_handler: OpenHandler<T>,
    /// Callback run for each incoming text message.
    text_message_handler: TextMessageHandler<T>,
    /// Callback run for each incoming binary message.
    binary_message_handler: BinaryMessageHandler<T>,
    /// Callback run when the peer sends a close frame.
    close_handler: CloseHandler,
    /// Lifecycle state, shared with the message loop and connection handles.
    pub(crate) state: Arc<tokio::sync::Mutex<ConnState>>,
    /// Registry of connected clients, used by the broadcaster built in `on_open`.
    clients: Arc<tokio::sync::Mutex<Vec<(Arc<Connection<T>>, Arc<ConnectionHandle<T>>)>>>,
    /// Handle passed to callbacks; set through `set_handle`.
    pub(crate) handle: Arc<tokio::sync::Mutex<Option<Arc<ConnectionHandle<T>>>>>,
}
impl<T: AsyncRead + AsyncWrite + Debug + Unpin + Send + 'static> std::fmt::Debug
    for Connection<T>
{
    /// Shows only the connection's `id` and `addr`; streams, callbacks and shared state
    /// are deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Connection");
        builder.field("id", &self.id);
        builder.field("addr", &self.addr);
        builder.finish()
    }
}
/// Lifecycle state of a [`Connection`].
///
/// In this module a connection starts as `CONNECTING`, becomes `OPEN` when
/// [`Connection::on_open`] starts its message loop, and becomes `CLOSED` when that loop
/// stops. `CLOSING` is not assigned in this module — presumably set elsewhere; confirm.
// The upper-case variant names are part of the public API (they appear to mirror the
// WebSocket `readyState` constant names), so the camel-case lint is silenced rather than
// renaming them.
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ConnState {
    OPEN,
    CLOSED,
    CONNECTING,
    CLOSING,
}
impl std::fmt::Display for ConnState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConnState::OPEN => write!(f, "OPEN"),
ConnState::CLOSED => write!(f, "CLOSED"),
ConnState::CONNECTING => write!(f, "CONNECTING"),
ConnState::CLOSING => write!(f, "CLOSING"),
}
}
}
impl<T> Connection<T>
where
    T: AsyncRead + AsyncWrite + Unpin + Debug + Send + 'static,
{
    /// Wraps an accepted WebSocket stream as connection `id` associated with `addr`.
    ///
    /// The stream is split into read and write halves, each behind its own async mutex.
    /// The connection starts in [`ConnState::CONNECTING`] with no callbacks, an empty client
    /// registry and no [`ConnectionHandle`]; nothing is read from the socket until
    /// [`Connection::on_open`] starts the message loop.
    pub(crate) fn new(id: u64, websocket: WebSocketStream<T>, addr: SocketAddr) -> Self {
        let (writer, reader) = futures::StreamExt::split(websocket);
        Self {
            id,
            state: Arc::new(tokio::sync::Mutex::new(ConnState::CONNECTING)),
            reader: Arc::new(tokio::sync::Mutex::new(reader)),
            writer: Arc::new(tokio::sync::Mutex::new(writer)),
            addr,
            open_handler: Arc::new(tokio::sync::Mutex::new(None)),
            text_message_handler: Arc::new(tokio::sync::Mutex::new(None)),
            binary_message_handler: Arc::new(tokio::sync::Mutex::new(None)),
            close_handler: Arc::new(tokio::sync::Mutex::new(None)),
            clients: Arc::new(tokio::sync::Mutex::new(Vec::new())),
            handle: Arc::new(tokio::sync::Mutex::new(None)),
        }
    }

    /// Replaces this connection's client registry with a shared one.
    pub(crate) fn set_clients_registry(
        &mut self,
        clients: Arc<tokio::sync::Mutex<Vec<(Arc<Connection<T>>, Arc<ConnectionHandle<T>>)>>>,
    ) {
        self.clients = clients;
    }

    /// Stores the handle that [`Connection::on_open`] passes to callbacks.
    pub(crate) async fn set_handle(&self, handle: Arc<ConnectionHandle<T>>) {
        let mut h = self.handle.lock().await;
        *h = Some(handle);
    }

    /// Returns the identifier this connection was created with.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Returns the socket address this connection was created with.
    pub fn addr(&self) -> SocketAddr {
        self.addr
    }

    /// Returns a snapshot of the current lifecycle state.
    pub async fn state(&self) -> ConnState {
        let s = self.state.lock().await;
        s.clone()
    }

    /// Registers the open callback and starts processing the connection.
    ///
    /// Spawns a task that marks the connection [`ConnState::OPEN`], runs `handler`, and then
    /// reads messages until the connection closes. Callbacks receive the handle stored via
    /// [`Connection::set_handle`], or a fallback handle if none was set. Each call spawns
    /// a new task.
    pub async fn on_open<F, Fut>(&self, handler: F)
    where
        F: Fn(Arc<ConnectionHandle<T>>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        // Scope the guard so the slot is not kept locked across the awaits below.
        {
            let mut slot = self.open_handler.lock().await;
            *slot = Some(Box::new(move |handle| Box::pin(handler(handle))));
        }

        let existing = self.handle.lock().await.clone();
        let handle = match existing {
            Some(h) => h,
            None => self.fallback_handle(),
        };

        let open_handler = Arc::clone(&self.open_handler);
        let text_handler = Arc::clone(&self.text_message_handler);
        let binary_handler = Arc::clone(&self.binary_message_handler);
        let close_handler = Arc::clone(&self.close_handler);
        let reader = Arc::clone(&self.reader);
        let state = Arc::clone(&self.state);
        tokio::spawn(async move {
            *state.lock().await = ConnState::OPEN;
            Self::dispatch(&open_handler, |h| h(Arc::clone(&handle))).await;
            Self::message_loop(handle, text_handler, binary_handler, close_handler, reader, state)
                .await;
        });
    }

    /// Sets (or replaces) the callback run for each incoming binary message.
    ///
    /// Takes effect immediately when the callback slot is uncontended. Otherwise the
    /// update is applied by a spawned task, which needs a running Tokio runtime.
    pub fn on_binary<F, Fut>(&self, handler: F)
    where
        F: Fn(BinaryMessageEvent, Arc<ConnectionHandle<T>>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let boxed: Box<
            dyn Fn(BinaryMessageEvent, Arc<ConnectionHandle<T>>) -> BoxFuture<()> + Send + Sync,
        > = Box::new(move |msg, handle| Box::pin(handler(msg, handle)));
        Self::install_handler(&self.binary_message_handler, boxed);
    }

    /// Sets (or replaces) the callback run for each incoming text message.
    ///
    /// Takes effect immediately when the callback slot is uncontended. Otherwise the
    /// update is applied by a spawned task, which needs a running Tokio runtime.
    pub fn on_text<F, Fut>(&self, handler: F)
    where
        F: Fn(TextMessageEvent, Arc<ConnectionHandle<T>>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let boxed: Box<
            dyn Fn(TextMessageEvent, Arc<ConnectionHandle<T>>) -> BoxFuture<()> + Send + Sync,
        > = Box::new(move |msg, handle| Box::pin(handler(msg, handle)));
        Self::install_handler(&self.text_message_handler, boxed);
    }

    /// Sets (or replaces) the callback run when the peer sends a close frame.
    ///
    /// Takes effect immediately when the callback slot is uncontended. Otherwise the
    /// update is applied by a spawned task, which needs a running Tokio runtime.
    pub fn on_close<F, Fut>(&self, handler: F)
    where
        F: Fn(CloseEvent) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let boxed: Box<dyn Fn(CloseEvent) -> BoxFuture<()> + Send + Sync> =
            Box::new(move |event| Box::pin(handler(event)));
        Self::install_handler(&self.close_handler, boxed);
    }

    /// Stores `handler` in `slot`, replacing any previous value.
    ///
    /// The slot is written synchronously when its lock is free. Registering before
    /// messages flow is then race-free, which the previous always-spawn approach was not.
    /// If the lock is held (e.g. the message loop is dispatching), the write is deferred
    /// to a spawned task.
    fn install_handler<H>(slot: &Arc<tokio::sync::Mutex<Option<H>>>, handler: H)
    where
        H: Send + 'static,
    {
        match slot.try_lock() {
            Ok(mut guard) => *guard = Some(handler),
            Err(_) => {
                let slot = Arc::clone(slot);
                tokio::spawn(async move {
                    *slot.lock().await = Some(handler);
                });
            }
        }
    }

    /// Builds a stand-in handle for when `on_open` runs before [`Connection::set_handle`].
    ///
    /// It shares this connection's writer, state and client registry. The receiving half
    /// of its room channel is dropped immediately, so sends on `room_sender` will fail.
    /// NOTE(review): confirm that is the intended behavior for the fallback handle.
    fn fallback_handle(&self) -> Arc<ConnectionHandle<T>> {
        let (tx, _rx) = tokio::sync::mpsc::channel::<RoomEvents<T>>(1);
        let (response_tx, response_rx) = tokio::sync::mpsc::channel::<Vec<&'static str>>(1);
        Arc::new(crate::handle::ConnectionHandle {
            id: self.id,
            writer: Arc::clone(&self.writer),
            addr: self.addr,
            broadcast: crate::handle::Broadcaster {
                clients: Arc::clone(&self.clients),
                current_client_id: self.id,
            },
            state: Arc::clone(&self.state),
            room_sender: Arc::new(tx),
            response_sender: Arc::new(response_tx),
            response_receiver: Arc::new(tokio::sync::Mutex::new(response_rx)),
        })
    }

    /// Runs the callback in `slot`, if one is set, via `call`.
    ///
    /// The slot's lock is released before the returned future is awaited. Callbacks can
    /// therefore re-register handlers without deadlocking.
    async fn dispatch<H>(
        slot: &tokio::sync::Mutex<Option<H>>,
        call: impl FnOnce(&H) -> BoxFuture<()>,
    ) {
        let pending = slot.lock().await.as_ref().map(call);
        if let Some(fut) = pending {
            fut.await;
        }
    }

    /// Records that the connection has closed.
    async fn mark_closed(state: &tokio::sync::Mutex<ConnState>) {
        *state.lock().await = ConnState::CLOSED;
    }

    /// Reads messages until the connection ends, dispatching each to its callback.
    ///
    /// Pings are answered with a pong carrying the same payload. The loop stops and marks
    /// the state [`ConnState::CLOSED`] on three events:
    /// - a close frame, after running the close callback;
    /// - a read error, which is logged to stderr;
    /// - the end of the stream.
    async fn message_loop(
        handle: Arc<ConnectionHandle<T>>,
        text_message_handler: TextMessageHandler<T>,
        binary_message_handler: BinaryMessageHandler<T>,
        close_handler: CloseHandler,
        reader: Arc<tokio::sync::Mutex<futures::stream::SplitStream<WebSocketStream<T>>>>,
        state: Arc<tokio::sync::Mutex<ConnState>>,
    ) {
        loop {
            // Hold the reader lock only while waiting for the next item. The result is
            // bound before the guard goes out of scope.
            let next = {
                let mut rd = reader.lock().await;
                let item = futures::StreamExt::next(&mut *rd).await;
                item
            };
            match next {
                Some(Ok(Message::Text(text))) => {
                    Self::dispatch(&text_message_handler, |h| {
                        h(TextMessageEvent::new(text.to_string()), Arc::clone(&handle))
                    })
                    .await;
                }
                Some(Ok(Message::Binary(data))) => {
                    Self::dispatch(&binary_message_handler, |h| {
                        h(BinaryMessageEvent::new(data.to_vec()), Arc::clone(&handle))
                    })
                    .await;
                }
                Some(Ok(Message::Ping(payload))) => {
                    let mut w = handle.writer.lock().await;
                    let _ = futures::SinkExt::send(&mut *w, Message::Pong(payload)).await;
                }
                Some(Ok(Message::Close(close_frame))) => {
                    let close_event = match close_frame {
                        Some(e) => CloseEvent::new(e.code.into(), e.reason.to_string()),
                        // 1005: RFC 6455 "no status code was present".
                        None => CloseEvent::new(1005, "No status received".to_string()),
                    };
                    Self::dispatch(&close_handler, |h| h(close_event)).await;
                    // Per the tungstenite docs, the reply to a peer's close frame is only
                    // queued, and goes out on a later read or flush. Since we stop reading
                    // here, flush so the closing handshake completes. This is best-effort:
                    // a "connection closed" error here is expected and ignored.
                    {
                        let mut w = handle.writer.lock().await;
                        let _ = futures::SinkExt::flush(&mut *w).await;
                    }
                    Self::mark_closed(&state).await;
                    break;
                }
                // Pongs and raw frames need no action.
                Some(Ok(_)) => {}
                Some(Err(e)) => {
                    eprintln!("WebSocket error: {}", e);
                    Self::mark_closed(&state).await;
                    break;
                }
                // End of stream: further polls keep yielding `None` immediately, so
                // continuing the loop here would spin forever.
                None => {
                    Self::mark_closed(&state).await;
                    break;
                }
            }
        }
    }
}