use std::{fmt::Debug, net::SocketAddr, sync::Arc};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::Mutex,
};
use tokio_tungstenite::{WebSocketStream, tungstenite::Message};
use crate::{
conn::{ConnState, Connection},
room::{RoomEvents, RoomMethods},
};
/// Cloneable handle to a single WebSocket connection.
///
/// The outgoing sink (`writer`) and the lifecycle state (`state`) sit behind
/// `Arc<Mutex<..>>`, so every clone of a handle shares the same sink and state.
#[derive(Debug)]
pub struct ConnectionHandle<T>
where
T: AsyncRead + AsyncWrite + Unpin + Send + Debug + 'static,
{
/// Numeric id of this connection; `Broadcaster` compares it against its
/// `current_client_id` to decide whom to skip.
pub(crate) id: u64,
/// Write half of the split WebSocket stream. The async mutex serializes
/// outgoing frames from concurrent callers of the same handle.
pub(crate) writer: Arc<Mutex<futures::stream::SplitSink<WebSocketStream<T>, Message>>>,
/// Socket address recorded for this connection (presumably the remote peer —
/// confirm at the construction site).
pub(crate) addr: SocketAddr,
/// Broadcaster used to send messages to the clients in the shared registry.
pub broadcast: Broadcaster<T>,
/// Current lifecycle state; `close` sets it to `ConnState::CLOSING`.
pub(crate) state: Arc<Mutex<ConnState>>,
/// Channel on which room join/leave requests (`RoomEvents`) are sent.
pub(crate) room_sender: tokio::sync::mpsc::Sender<RoomEvents<T>>,
}
impl<T> Clone for ConnectionHandle<T>
where
T: AsyncRead + AsyncWrite + Unpin + Debug + Send + 'static,
{
fn clone(&self) -> Self {
Self {
id: self.id,
writer: self.writer.clone(),
addr: self.addr,
broadcast: self.broadcast.clone(),
state: self.state.clone(),
room_sender: self.room_sender.clone(),
}
}
}
impl<T> ConnectionHandle<T>
where
    T: AsyncRead + AsyncWrite + Unpin + Debug + Send + 'static,
{
    /// Returns this connection's numeric id.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Returns the socket address recorded for this connection.
    pub fn addr(&self) -> SocketAddr {
        self.addr
    }

    /// Returns a copy of the connection's current [`ConnState`].
    ///
    /// Waits for the state lock; the returned value is a snapshot and does not
    /// follow later state changes.
    pub async fn state(&self) -> ConnState {
        let current = self.state.lock().await;
        (*current).clone()
    }

    /// Sends `text` to this connection as a WebSocket text frame.
    ///
    /// Concurrent senders on the same connection are serialized by the writer mutex.
    ///
    /// # Errors
    ///
    /// Returns the sink's error (boxed) if the frame cannot be sent.
    pub async fn send_text<S>(&self, text: S) -> Result<(), Box<dyn std::error::Error>>
    where
        S: Into<String>,
    {
        let owned: String = text.into();
        let frame = Message::Text(owned.into());
        let mut sink = self.writer.lock().await;
        futures::SinkExt::send(&mut *sink, frame).await?;
        Ok(())
    }

    /// Enqueues a `RoomEvents::JoinRoom` request (carrying a clone of this handle)
    /// for `room` on the room channel.
    ///
    /// This only delivers the request to the channel; the actual membership change
    /// is performed by whatever consumes `room_sender`'s receiver.
    ///
    /// # Errors
    ///
    /// Returns a boxed `io::Error` of kind `Other` if the channel's receiver is gone.
    pub async fn join(&self, room: &str) -> Result<(), Box<dyn std::error::Error>> {
        let event = RoomEvents::JoinRoom {
            client_id: self.id,
            handle: self.clone(),
            room_name: room.to_string(),
        };
        if let Err(e) = self.room_sender.send(event).await {
            let failure = std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Failed to join room: {}", e),
            );
            return Err(failure.into());
        }
        Ok(())
    }

    /// Enqueues a `RoomEvents::LeaveRoom` request for `room` on the room channel.
    ///
    /// # Errors
    ///
    /// Returns a boxed `io::Error` of kind `Other` if the channel's receiver is gone.
    pub async fn leave(&self, room: &str) -> Result<(), Box<dyn std::error::Error>> {
        let event = RoomEvents::LeaveRoom {
            client_id: self.id,
            room_name: room.to_string(),
        };
        if let Err(e) = self.room_sender.send(event).await {
            let failure = std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Failed to leave room: {}", e),
            );
            return Err(failure.into());
        }
        Ok(())
    }

    /// Returns a [`RoomMethods`] targeting `room_name`, carrying this connection's id
    /// and a reference to its room channel.
    pub fn to(&self, room_name: &str) -> RoomMethods<T> {
        RoomMethods {
            room_name: room_name.to_string(),
            id: self.id,
            room_sender: Arc::new(&self.room_sender),
        }
    }

    /// Sends `data` to this connection as a WebSocket binary frame.
    ///
    /// # Errors
    ///
    /// Returns the sink's error (boxed) if the frame cannot be sent.
    pub async fn send_binary(&self, data: Vec<u8>) -> Result<(), Box<dyn std::error::Error>> {
        let frame = Message::Binary(data.into());
        let mut sink = self.writer.lock().await;
        futures::SinkExt::send(&mut *sink, frame).await?;
        Ok(())
    }

    /// Marks the connection as `ConnState::CLOSING` and sends a Close frame with no
    /// payload.
    ///
    /// The state is updated before the frame is sent, so it remains `CLOSING` even
    /// if sending the Close frame fails.
    ///
    /// # Errors
    ///
    /// Returns the sink's error (boxed) if the Close frame cannot be sent.
    pub async fn close(&self) -> Result<(), Box<dyn std::error::Error>> {
        // The state guard is a temporary dropped at the end of this statement, so the
        // state lock is not held while the writer lock is awaited below.
        *self.state.lock().await = ConnState::CLOSING;

        let mut sink = self.writer.lock().await;
        futures::SinkExt::send(&mut *sink, Message::Close(None)).await?;
        Ok(())
    }
}
/// Fan-out helper for sending messages to the clients in a shared registry.
///
/// `clients` is behind an `Arc`, so clones of a `Broadcaster` operate on the same
/// registry.
#[derive(Debug)]
pub struct Broadcaster<T>
where
T: AsyncRead + AsyncWrite + Unpin + Send + Debug + 'static,
{
/// Id of the client this broadcaster acts for; `text` and `binary` skip the
/// handle with this id, while `emit_text` and `emit_binary` do not.
pub(crate) current_client_id: u64,
/// Shared registry of clients as `(Connection, ConnectionHandle)` pairs.
pub(crate) clients: Arc<Mutex<Vec<(Arc<Connection<T>>, Arc<ConnectionHandle<T>>)>>>,
}
impl<T> Clone for Broadcaster<T>
where
T: AsyncRead + AsyncWrite + Unpin + Debug + Send + 'static,
{
fn clone(&self) -> Self {
Self {
current_client_id: self.current_client_id,
clients: self.clients.clone(),
}
}
}
impl<T> Broadcaster<T>
where
    T: AsyncRead + AsyncWrite + Unpin + Debug + Send + 'static,
{
    /// Sends `text` as a text frame to every registered client except the one whose
    /// id equals `current_client_id`.
    ///
    /// Delivery is best-effort: a failure for one client is logged to stderr and the
    /// remaining clients are still attempted. Clients are written one after another,
    /// so a slow client delays delivery to those after it.
    pub async fn text<S>(&self, text: S)
    where
        S: Into<String>,
    {
        let payload: String = text.into();
        let recipients = self.recipients(true).await;
        Self::deliver_text(recipients, payload).await;
    }

    /// Sends `text` as a text frame to every registered client, including the one
    /// whose id equals `current_client_id`.
    ///
    /// Same best-effort, sequential delivery as [`Broadcaster::text`].
    pub async fn emit_text<S>(&self, text: S)
    where
        S: Into<String>,
    {
        let payload: String = text.into();
        let recipients = self.recipients(false).await;
        Self::deliver_text(recipients, payload).await;
    }

    /// Sends `bytes` as a binary frame to every registered client, including the one
    /// whose id equals `current_client_id`.
    ///
    /// Same best-effort, sequential delivery as [`Broadcaster::text`].
    pub async fn emit_binary<B>(&self, bytes: B)
    where
        B: Into<Vec<u8>>,
    {
        let payload: Vec<u8> = bytes.into();
        let recipients = self.recipients(false).await;
        Self::deliver_binary(recipients, payload).await;
    }

    /// Sends `bytes` as a binary frame to every registered client except the one whose
    /// id equals `current_client_id`.
    ///
    /// Same best-effort, sequential delivery as [`Broadcaster::text`].
    pub async fn binary<B>(&self, bytes: B)
    where
        B: Into<Vec<u8>>,
    {
        let payload: Vec<u8> = bytes.into();
        let recipients = self.recipients(true).await;
        Self::deliver_binary(recipients, payload).await;
    }

    /// Snapshots the registered client handles, skipping `current_client_id` when
    /// `exclude_current` is true.
    ///
    /// The registry lock is released when this returns, so it is never held while the
    /// callers await per-client network writes.
    async fn recipients(&self, exclude_current: bool) -> Vec<Arc<ConnectionHandle<T>>> {
        let clients = self.clients.lock().await;
        clients
            .iter()
            .map(|(_, handle)| handle)
            .filter(|handle| !exclude_current || handle.id() != self.current_client_id)
            .map(Arc::clone)
            .collect()
    }

    /// Sends `payload` as a text frame to each recipient in order, logging failures.
    async fn deliver_text(recipients: Vec<Arc<ConnectionHandle<T>>>, payload: String) {
        for handle in recipients {
            if let Err(e) = handle.send_text(payload.clone()).await {
                eprintln!("Failed to broadcast to client {}: {}", handle.id(), e);
            }
        }
    }

    /// Sends `payload` as a binary frame to each recipient in order, logging failures.
    async fn deliver_binary(recipients: Vec<Arc<ConnectionHandle<T>>>, payload: Vec<u8>) {
        for handle in recipients {
            if let Err(e) = handle.send_binary(payload.clone()).await {
                eprintln!("Failed to broadcast to client {}: {}", handle.id(), e);
            }
        }
    }
}