use std::{fmt::Debug, net::SocketAddr, sync::Arc};
use futures::FutureExt;
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::Mutex,
};
use tokio_tungstenite::{WebSocketStream, tungstenite::Message};
use crate::{
types::{BinaryMessageEvent, CloseEvent, TextMessageEvent},
wynd::BoxFuture,
};
/// Shared, lockable slot for the callback registered through `Connection::on_open`.
///
/// `Connection::new` initialises it to `None`.
type OpenHandler<T> =
    Arc<Mutex<Option<Box<dyn Fn(Arc<ConnectionHandle<T>>) -> BoxFuture<()> + Send + Sync>>>>;

/// Shared, lockable slot for the callback invoked with every incoming text message.
type TextMessageHandler<T> = Arc<
    Mutex<
        Option<
            Box<dyn Fn(TextMessageEvent, Arc<ConnectionHandle<T>>) -> BoxFuture<()> + Send + Sync>,
        >,
    >,
>;

/// Shared, lockable slot for the callback invoked with every incoming binary message.
type BinaryMessageHandler<T> = Arc<
    Mutex<
        Option<
            Box<
                dyn Fn(BinaryMessageEvent, Arc<ConnectionHandle<T>>) -> BoxFuture<()> + Send + Sync,
            >,
        >,
    >,
>;

/// Shared, lockable slot for the callback invoked when a close frame is received.
type CloseHandler = Arc<Mutex<Option<Box<dyn Fn(CloseEvent) -> BoxFuture<()> + Send + Sync>>>>;
/// State for a single WebSocket connection.
///
/// Holds the split stream halves, the user-registered event handlers, the lifecycle state
/// and a handle on the client registry used for broadcasting.
pub struct Connection<T: AsyncRead + AsyncWrite + Unpin + Debug + Send + 'static> {
    /// Identifier supplied at construction.
    id: u64,
    /// Read half of the WebSocket stream, polled by the message loop.
    reader: Arc<Mutex<futures::stream::SplitStream<WebSocketStream<T>>>>,
    /// Write half of the WebSocket stream, shared with every [`ConnectionHandle`] made for
    /// this connection.
    pub(crate) writer: Arc<Mutex<futures::stream::SplitSink<WebSocketStream<T>, Message>>>,
    /// Socket address supplied at construction.
    addr: SocketAddr,
    /// Callback registered via `on_open`.
    open_handler: OpenHandler<T>,
    /// Callback registered via `on_text`.
    text_message_handler: TextMessageHandler<T>,
    /// Callback registered via `on_binary`.
    binary_message_handler: BinaryMessageHandler<T>,
    /// Callback registered via `on_close`.
    close_handler: CloseHandler,
    /// Lifecycle state, shared with the connection's handles.
    pub(crate) state: Arc<Mutex<ConnState>>,
    /// Registry of `(connection, handle)` pairs that broadcasters send to.
    clients: Arc<Mutex<Vec<(Arc<Connection<T>>, Arc<ConnectionHandle<T>>)>>>,
}
impl<T> std::fmt::Debug for Connection<T>
where
    T: AsyncRead + AsyncWrite + Debug + Unpin + Send + 'static,
{
    /// Formats the connection as `Connection { id: .., addr: .. }`.
    ///
    /// Only the identifying fields are printed; stream halves, handlers, state and the
    /// client registry are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { id, addr, .. } = self;
        let mut builder = f.debug_struct("Connection");
        builder.field("id", id);
        builder.field("addr", addr);
        builder.finish()
    }
}
/// Lifecycle state of a connection.
///
/// Supports `==` and hashing, so callers can write `state == ConnState::OPEN` instead of
/// pattern matching.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ConnState {
    /// Set when the connection's read task starts, before the open handler runs.
    OPEN,
    /// Set when the read loop stops.
    CLOSED,
    /// Initial state assigned by `Connection::new`.
    CONNECTING,
    /// Set by `ConnectionHandle::close` before the close frame is sent.
    CLOSING,
}
impl std::fmt::Display for ConnState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConnState::OPEN => write!(f, "OPEN"),
ConnState::CLOSED => write!(f, "CLOSED"),
ConnState::CONNECTING => write!(f, "CONNECTING"),
ConnState::CLOSING => write!(f, "CLOSING"),
}
}
}
/// Handle to a connection that is passed to user-registered handlers.
///
/// Shares the write half and state with the owning [`Connection`], so it can send messages
/// and close the connection from inside a handler.
#[derive(Debug)]
pub struct ConnectionHandle<T: AsyncRead + AsyncWrite + Unpin + Send + Debug + 'static> {
    /// Id of the owning connection.
    pub(crate) id: u64,
    /// Write half of the WebSocket stream, shared with the owning connection.
    pub(crate) writer: Arc<Mutex<futures::stream::SplitSink<WebSocketStream<T>, Message>>>,
    /// Socket address of the owning connection.
    pub(crate) addr: SocketAddr,
    /// Broadcaster bound to this connection's id.
    pub broadcast: Broadcaster<T>,
    /// Lifecycle state, shared with the owning connection.
    pub(crate) state: Arc<Mutex<ConnState>>,
}
/// Sends messages to the connections in a shared client registry.
///
/// Remembers the id of the connection it was created for, so that [`Broadcaster::text`]
/// and [`Broadcaster::binary`] can skip that connection.
#[derive(Debug)]
pub struct Broadcaster<T: AsyncRead + AsyncWrite + Unpin + Send + Debug + 'static> {
    /// Id of the connection this broadcaster belongs to.
    pub(crate) current_client_id: u64,
    /// Shared registry of `(connection, handle)` pairs that messages are sent to.
    pub(crate) clients: Arc<Mutex<Vec<(Arc<Connection<T>>, Arc<ConnectionHandle<T>>)>>>,
}
impl<T> Broadcaster<T>
where
    T: AsyncRead + AsyncWrite + Unpin + Debug + Send + 'static,
{
    /// Sends `text` to every registered client except the one this broadcaster belongs to.
    ///
    /// A failed send is logged to stderr and does not stop delivery to the other clients.
    pub async fn text(&self, text: &str) {
        // Use a snapshot: the registry lock must not be held across network writes, or a
        // slow peer would block every other registry user for the whole broadcast.
        for h in self.recipients(true).await {
            if let Err(e) = h.send_text(text).await {
                eprintln!("Failed to broadcast to client {}: {}", h.id(), e);
            }
        }
    }

    /// Sends `text` to every registered client, including the one this broadcaster belongs to.
    ///
    /// A failed send is logged to stderr and does not stop delivery to the other clients.
    pub async fn emit_text(&self, text: &str) {
        for h in self.recipients(false).await {
            if let Err(e) = h.send_text(text).await {
                eprintln!("Failed to broadcast to client {}: {}", h.id(), e);
            }
        }
    }

    /// Sends `bytes` as a binary message to every registered client, including the one this
    /// broadcaster belongs to.
    ///
    /// A failed send is logged to stderr and does not stop delivery to the other clients.
    pub async fn emit_binary(&self, bytes: &[u8]) {
        for h in self.recipients(false).await {
            // `send_binary` takes ownership of its payload, so each recipient gets its own copy.
            if let Err(e) = h.send_binary(bytes.to_vec()).await {
                eprintln!("Failed to broadcast to client {}: {}", h.id(), e);
            }
        }
    }

    /// Sends `bytes` as a binary message to every registered client except the one this
    /// broadcaster belongs to.
    ///
    /// A failed send is logged to stderr and does not stop delivery to the other clients.
    pub async fn binary(&self, bytes: &[u8]) {
        for h in self.recipients(true).await {
            if let Err(e) = h.send_binary(bytes.to_vec()).await {
                eprintln!("Failed to broadcast to client {}: {}", h.id(), e);
            }
        }
    }

    /// Snapshots the handles currently in the registry and releases the lock before returning.
    ///
    /// When `skip_self` is true, the handle whose id equals `current_client_id` is left out.
    async fn recipients(&self, skip_self: bool) -> Vec<Arc<ConnectionHandle<T>>> {
        let clients = self.clients.lock().await;
        clients
            .iter()
            .filter(|(_, h)| !(skip_self && h.id() == self.current_client_id))
            .map(|(_, h)| Arc::clone(h))
            .collect()
    }
}
impl<T> Connection<T>
where
T: AsyncRead + AsyncWrite + Unpin + Debug + Send + 'static,
{
pub(crate) fn new(id: u64, websocket: WebSocketStream<T>, addr: SocketAddr) -> Self {
let (writer, reader) = futures::StreamExt::split(websocket);
Self {
id,
state: Arc::new(Mutex::new(ConnState::CONNECTING)),
reader: Arc::new(Mutex::new(reader)),
writer: Arc::new(Mutex::new(writer)),
addr,
open_handler: Arc::new(Mutex::new(None)),
text_message_handler: Arc::new(Mutex::new(None)),
binary_message_handler: Arc::new(Mutex::new(None)),
close_handler: Arc::new(Mutex::new(None)),
clients: Arc::new(Mutex::new(Vec::new())),
}
}
pub(crate) fn set_clients_registry(
&mut self,
clients: Arc<Mutex<Vec<(Arc<Connection<T>>, Arc<ConnectionHandle<T>>)>>>,
) {
self.clients = clients;
}
pub fn id(&self) -> u64 {
self.id
}
pub fn addr(&self) -> SocketAddr {
self.addr
}
pub async fn state(&self) -> ConnState {
let s = self.state.lock().await;
s.clone()
}
pub async fn on_open<F, Fut>(&self, handler: F)
where
F: Fn(Arc<ConnectionHandle<T>>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let mut open_handler: tokio::sync::MutexGuard<
'_,
Option<
Box<
dyn Fn(
Arc<ConnectionHandle<T>>,
)
-> std::pin::Pin<Box<dyn Future<Output = ()> + Send>>
+ Send
+ Sync,
>,
>,
> = self.open_handler.lock().await;
*open_handler = Some(Box::new(move |handle| Box::pin(handler(handle))));
let broadcaster = Broadcaster {
clients: Arc::clone(&self.clients),
current_client_id: self.id,
};
let handle = Arc::new(ConnectionHandle {
id: self.id,
writer: Arc::clone(&self.writer),
addr: self.addr,
broadcast: broadcaster,
state: Arc::clone(&self.state),
});
let open_handler_clone = Arc::clone(&self.open_handler);
let text_message_handler_clone = Arc::clone(&self.text_message_handler);
let binary_message_handler_clone = Arc::clone(&self.binary_message_handler);
let close_handler_clone = Arc::clone(&self.close_handler);
let handle_clone = Arc::clone(&handle);
let reader_clone = Arc::clone(&self.reader);
let state_clone = Arc::clone(&self.state);
tokio::spawn(async move {
{
{
let mut s = state_clone.lock().await;
*s = ConnState::OPEN;
}
let open_handler = open_handler_clone.lock().await;
if let Some(ref handler) = *open_handler {
handler(Arc::clone(&handle_clone)).await;
}
}
Self::message_loop(
handle_clone,
text_message_handler_clone,
binary_message_handler_clone,
close_handler_clone,
reader_clone,
state_clone,
)
.await;
});
}
pub fn on_binary<F, Fut>(&self, handler: F)
where
F: Fn(BinaryMessageEvent, Arc<ConnectionHandle<T>>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let binary_message_handler = Arc::clone(&self.binary_message_handler);
tokio::spawn(async move {
let mut lock = binary_message_handler.lock().await;
*lock = Some(Box::new(move |msg, handle| Box::pin(handler(msg, handle))));
});
}
pub fn on_text<F, Fut>(&self, handler: F)
where
F: Fn(TextMessageEvent, Arc<ConnectionHandle<T>>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let text_message_handler = Arc::clone(&self.text_message_handler);
tokio::task::block_in_place(|| {}); let text_message_handler_fut = async move {
let mut lock = text_message_handler.lock().await;
*lock = Some(Box::new(move |msg, handle| Box::pin(handler(msg, handle))));
};
text_message_handler_fut.now_or_never();
}
pub fn on_close<F, Fut>(&self, handler: F)
where
F: Fn(CloseEvent) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let close_handler = Arc::clone(&self.close_handler);
tokio::spawn(async move {
let mut lock = close_handler.lock().await;
*lock = Some(Box::new(move |event| Box::pin(handler(event))));
});
}
async fn message_loop(
handle: Arc<ConnectionHandle<T>>,
text_message_handler: TextMessageHandler<T>,
binary_message_handler: BinaryMessageHandler<T>,
close_handler: CloseHandler,
reader: Arc<Mutex<futures::stream::SplitStream<WebSocketStream<T>>>>,
state: Arc<Mutex<ConnState>>,
) {
loop {
let msg = {
let mut rd = reader.lock().await;
futures::StreamExt::next(&mut *rd).await
};
match msg {
Some(Ok(Message::Text(text))) => {
let handler = text_message_handler.lock().await;
if let Some(ref h) = *handler {
h(TextMessageEvent::new(text.to_string()), Arc::clone(&handle)).await;
}
}
Some(Ok(Message::Ping(payload))) => {
let mut w = handle.writer.lock().await;
let _ = futures::SinkExt::send(&mut *w, Message::Pong(payload)).await;
}
Some(Ok(Message::Pong(_))) => {
}
Some(Ok(Message::Binary(data))) => {
let handler = binary_message_handler.lock().await;
if let Some(ref h) = *handler {
h(BinaryMessageEvent::new(data.to_vec()), Arc::clone(&handle)).await;
}
}
Some(Ok(Message::Close(close_frame))) => {
let close_event = match close_frame {
Some(e) => CloseEvent::new(e.code.into(), e.reason.to_string()),
None => CloseEvent::new(1005, "No status received".to_string()),
};
let handler = close_handler.lock().await;
if let Some(ref h) = *handler {
h(close_event).await;
}
{
let mut s = state.lock().await;
*s = ConnState::CLOSED;
}
break;
}
Some(Err(e)) => {
eprintln!("WebSocket error: {}", e);
{
let mut s = state.lock().await;
*s = ConnState::CLOSED;
}
break;
}
_ => {}
}
}
}
}
impl<T> ConnectionHandle<T>
where
T: AsyncRead + AsyncWrite + Unpin + Debug + Send + 'static,
{
pub fn id(&self) -> u64 {
self.id
}
pub fn addr(&self) -> SocketAddr {
self.addr
}
pub async fn state(&self) -> ConnState {
let s = self.state.lock().await;
s.clone()
}
pub async fn send_text(&self, text: &str) -> Result<(), Box<dyn std::error::Error>> {
let mut writer = self.writer.lock().await;
futures::SinkExt::send(&mut *writer, Message::Text(text.into())).await?;
Ok(())
}
pub async fn send_binary(&self, data: Vec<u8>) -> Result<(), Box<dyn std::error::Error>> {
let mut writer = self.writer.lock().await;
futures::SinkExt::send(&mut *writer, Message::Binary(data.into())).await?;
Ok(())
}
pub async fn close(&self) -> Result<(), Box<dyn std::error::Error>> {
{
let mut s = self.state.lock().await;
*s = ConnState::CLOSING;
}
let mut writer = self.writer.lock().await;
futures::SinkExt::send(&mut *writer, Message::Close(None)).await?;
Ok(())
}
}