#[cfg(feature = "with-ripress")]
use hyper_tungstenite::hyper;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::Mutex;
use tokio::sync::mpsc::Receiver;
use std::collections::HashMap;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;
use tokio_tungstenite::accept_async;
use crate::conn::Connection;
use crate::handle::{Broadcaster, ConnectionHandle};
use crate::room::{Room, RoomEvents};
use crate::types::WyndError;
use std::fmt::Debug;
/// Counter type used to hand out connection ids; each id is the `u64` returned
/// by `fetch_add` on it.
pub(crate) type ConnectionId = AtomicU64;
/// Boxed, `Send`, `'static` future: the return type of the stored async handlers.
pub(crate) type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;

/// A WebSocket server generic over the underlying byte stream `T`
/// (`TcpStream` for the standalone listener, an upgraded hyper connection
/// when the `with-ripress` feature is enabled).
pub struct Wynd<T>
where
    T: AsyncRead + AsyncWrite + Unpin + Debug + Send + 'static,
{
    /// Async handler run for each newly established connection (set by `on_connection`).
    pub(crate) connection_handler:
        Option<Box<dyn Fn(Arc<Connection<T>>) -> BoxFuture<()> + Send + Sync + 'static>>,
    /// Address associated with the server; defaults to `0.0.0.0:8080` and is
    /// overwritten with the bound address by `listen`.
    pub(crate) addr: SocketAddr,
    /// Async handler for server errors, e.g. `accept()` failures (set by `on_error`).
    pub(crate) error_handler:
        Option<Box<dyn Fn(WyndError) -> BoxFuture<()> + Send + Sync + 'static>>,
    /// Callback invoked when the server value is dropped (set by `on_close`).
    pub(crate) close_handler: Option<Box<dyn Fn() + Send + Sync + 'static>>,
    /// Source of connection ids; incremented once per connection.
    pub(crate) next_connection_id: ConnectionId,
    /// Every registered connection paired with its handle.
    pub clients: Arc<Mutex<Vec<(Arc<Connection<T>>, Arc<ConnectionHandle<T>>)>>>,
    /// Named rooms and their members.
    pub rooms: Arc<Mutex<Vec<Room<T>>>>,
    /// Sender for room events; each connection handle receives a clone of it.
    room_sender: Arc<tokio::sync::mpsc::Sender<RoomEvents<T>>>,
    /// Receiving end of the channel created in `new`; kept alive but never read in this file.
    _room_receiver: Arc<Mutex<tokio::sync::mpsc::Receiver<RoomEvents<T>>>>,
}
impl<T> Debug for Wynd<T>
where
    T: AsyncRead + AsyncWrite + Unpin + Debug + Send + 'static,
{
    /// Prints only the type name; the handlers and registries inside are opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Wynd")
    }
}
pub type Standalone = TcpStream;
#[cfg(feature = "with-ripress")]
pub type WithRipress = hyper::upgrade::Upgraded;
impl<T> Drop for Wynd<T>
where
    T: AsyncRead + AsyncWrite + Unpin + Debug + Send + 'static,
{
    /// Runs the `on_close` callback, if one was registered, as the server value goes away.
    fn drop(&mut self) {
        if let Some(on_close) = &self.close_handler {
            on_close();
        }
    }
}
impl<T> Wynd<T>
where
T: AsyncRead + Debug + AsyncWrite + Send + 'static + Unpin,
{
pub fn new() -> Self {
let (room_sender, room_receiver) = tokio::sync::mpsc::channel(100);
Self {
connection_handler: None,
error_handler: None,
close_handler: None,
next_connection_id: ConnectionId::new(0),
clients: Arc::new(tokio::sync::Mutex::new(Vec::new())),
addr: SocketAddr::from(([0, 0, 0, 0], 8080)),
rooms: Arc::new(tokio::sync::Mutex::new(Vec::new())),
room_sender: Arc::new(room_sender),
_room_receiver: Arc::new(Mutex::new(room_receiver)),
}
}
pub fn on_connection<F, Fut>(&mut self, handler: F)
where
F: Fn(Arc<Connection<T>>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
self.connection_handler = Some(Box::new(move |conn| Box::pin(handler(conn))));
}
pub fn on_error<F, Fut>(&mut self, handler: F)
where
F: Fn(WyndError) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
self.error_handler = Some(Box::new(move |err| Box::pin(handler(err))));
}
pub fn on_close<F>(&mut self, handler: F)
where
F: Fn() -> () + Send + Sync + 'static,
{
self.close_handler = Some(Box::new(move || handler()));
}
async fn handle_connection(
&mut self,
stream: T,
addr: SocketAddr,
) -> Result<(), Box<dyn std::error::Error>> {
let websocket = match timeout(Duration::from_secs(10), accept_async(stream)).await {
Ok(res) => res?, Err(_) => {
return Err(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"WebSocket handshake timed out",
)
.into());
}
};
let connection_id = self.next_connection_id.fetch_add(1, Ordering::Relaxed);
let mut connection = Connection::new(connection_id, websocket, addr);
connection.set_clients_registry(Arc::clone(&self.clients));
let broadcaster = Broadcaster {
clients: Arc::clone(&self.clients),
current_client_id: connection_id,
};
let (response_sender, response_receiver) = tokio::sync::mpsc::channel(10);
let handle = Arc::new(ConnectionHandle {
id: connection.id(),
writer: Arc::clone(&connection.writer),
addr: addr,
broadcast: broadcaster,
state: Arc::clone(&connection.state),
room_sender: Arc::clone(&self.room_sender),
response_sender: Arc::new(response_sender),
response_receiver: Arc::new(Mutex::new(response_receiver)),
});
let arc_connection = Arc::new(connection);
arc_connection.set_handle(Arc::clone(&handle)).await;
{
let mut clients = self.clients.lock().await;
clients.push((Arc::clone(&arc_connection), Arc::clone(&handle)));
}
{
let clients_registry = Arc::clone(&self.clients);
let rooms_registry = Arc::clone(&self.rooms);
let handle_id = handle.id();
arc_connection.on_close(move |_event| {
let clients_registry = Arc::clone(&clients_registry);
let rooms_registry = Arc::clone(&rooms_registry);
async move {
let mut clients = clients_registry.lock().await;
clients.retain(|(_c, h)| h.id() != handle_id);
let mut rooms = rooms_registry.lock().await;
for room in rooms.iter_mut() {
room.room_clients.remove(&handle_id);
}
rooms.retain(|room| !room.room_clients.is_empty());
}
});
}
arc_connection
.on_open(|_handle| async move {
})
.await;
if let Some(ref handler) = self.connection_handler {
handler(arc_connection).await;
}
Ok(())
}
}
impl Wynd<TcpStream> {
pub async fn listen<F>(
mut self,
port: u16,
on_listening: F,
) -> Result<(), Box<dyn std::error::Error>>
where
F: FnOnce() + Send + 'static,
{
let addr = format!("127.0.0.1:{}", port);
let listener = TcpListener::bind(&addr).await?;
self.addr = listener.local_addr().unwrap();
let (room_sender, room_receiver) = tokio::sync::mpsc::channel::<RoomEvents<TcpStream>>(100);
self.room_sender = Arc::new(room_sender);
let rooms = Arc::clone(&self.rooms);
let clients = Arc::clone(&self.clients);
Self::handle_communication(room_receiver, rooms, clients);
on_listening();
let wynd = Arc::new(Mutex::new(self));
loop {
match listener.accept().await {
Ok((stream, addr)) => {
let wynd_clone = Arc::clone(&wynd);
tokio::spawn(async move {
if let Err(e) = wynd_clone
.lock()
.await
.handle_connection(stream, addr)
.await
{
eprintln!("Error handling connection: {}", e);
}
});
}
Err(e) => {
let wynd_guard = wynd.lock().await;
let handler = wynd_guard.error_handler.as_ref();
if let Some(handler) = handler {
handler(WyndError::new(e.to_string())).await;
} else {
eprintln!("Error accepting connection: {}", e);
}
eprintln!("accept() failed: {e}. Retrying...");
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
}
}
}
}
fn handle_communication(
mut room_receiver: Receiver<RoomEvents<TcpStream>>,
rooms: Arc<Mutex<Vec<Room<TcpStream>>>>,
clients: Arc<Mutex<Vec<(Arc<Connection<TcpStream>>, Arc<ConnectionHandle<TcpStream>>)>>>,
) {
tokio::spawn(async move {
while let Some(room_data) = room_receiver.recv().await {
match room_data {
RoomEvents::JoinRoom {
client_id,
handle,
room_name,
} => {
let mut rooms = rooms.lock().await;
let maybe_room = rooms.iter_mut().find(|room| room.room_name == room_name);
if let Some(room) = maybe_room {
if room.room_clients.contains_key(&client_id) {
continue;
} else {
room.room_clients.insert(client_id, handle);
}
} else {
let room = Room {
room_clients: HashMap::from([(client_id, handle)]),
room_name,
};
rooms.push(room);
}
}
RoomEvents::TextMessage {
room_name,
text,
client_id,
} => {
let mut rooms = rooms.lock().await;
let maybe_room = rooms.iter_mut().find(|room| room.room_name == room_name);
if maybe_room.is_none() {
return;
}
if !maybe_room.unwrap().room_clients.contains_key(&client_id) {
return;
}
let handles: Vec<_> = {
if let Some(room) = rooms.iter().find(|r| r.room_name == room_name) {
room.room_clients.values().cloned().collect()
} else {
Vec::new()
}
};
if handles.is_empty() {
eprintln!("Room not found: {}", room_name);
} else {
for h in handles {
if h.id == client_id {
continue;
} else {
if let Err(e) = h.send_text(text.clone()).await {
eprintln!("Failed to send text to client: {}", e);
}
}
}
}
}
RoomEvents::BinaryMessage {
room_name,
bytes,
client_id,
} => {
let mut rooms = rooms.lock().await;
let maybe_room = rooms.iter_mut().find(|room| room.room_name == room_name);
if maybe_room.is_none() {
return;
}
if !maybe_room.unwrap().room_clients.contains_key(&client_id) {
return;
}
let recipients = {
rooms
.iter()
.find(|r| r.room_name == room_name)
.map(|r| r.room_clients.values().cloned().collect::<Vec<_>>())
};
if let Some(recipients) = recipients {
for h in recipients {
if h.id == client_id {
continue;
} else {
if let Err(e) = h.send_binary(bytes.clone()).await {
eprintln!("Failed to send binary to client: {}", e);
}
}
}
} else {
println!("Room not found: {}", room_name);
}
}
RoomEvents::EmitTextMessage {
client_id,
room_name,
text,
} => {
let mut rooms = rooms.lock().await;
let maybe_room = rooms.iter_mut().find(|room| room.room_name == room_name);
if maybe_room.is_none() {
return;
}
if !maybe_room.unwrap().room_clients.contains_key(&client_id) {
return;
}
let handles: Vec<_> = {
if let Some(room) = rooms.iter().find(|r| r.room_name == room_name) {
room.room_clients.values().cloned().collect()
} else {
Vec::new()
}
};
if handles.is_empty() {
eprintln!("Room not found: {}", room_name);
} else {
for h in handles {
if let Err(e) = h.send_text(text.clone()).await {
eprintln!("Failed to send text to client: {}", e);
}
}
}
}
RoomEvents::EmitBinaryMessage {
client_id,
room_name,
bytes,
} => {
let mut rooms = rooms.lock().await;
let maybe_room = rooms.iter_mut().find(|room| room.room_name == room_name);
if maybe_room.is_none() {
return;
}
if !maybe_room.unwrap().room_clients.contains_key(&client_id) {
return;
}
let recipients = {
rooms
.iter()
.find(|r| r.room_name == room_name)
.map(|r| r.room_clients.values().cloned().collect::<Vec<_>>())
};
if let Some(recipients) = recipients {
for h in recipients {
if let Err(e) = h.send_binary(bytes.clone()).await {
eprintln!("Failed to send binary to client: {}", e);
}
}
} else {
println!("Room not found: {}", room_name);
}
}
RoomEvents::LeaveRoom {
client_id,
room_name,
} => {
let mut rooms_guard = rooms.lock().await;
let mut remove_room = false;
if let Some(room) = rooms_guard
.iter_mut()
.find(|room| room.room_name == room_name)
{
room.room_clients.remove(&client_id);
remove_room = room.room_clients.is_empty();
}
if remove_room {
rooms_guard.retain(|r| r.room_name != room_name);
}
}
RoomEvents::ListRooms { client_id } => {
let rooms_guard = rooms.lock().await;
let mut list = Vec::new();
for room in rooms_guard.iter() {
list.push(room.room_name.clone());
}
let clients_guard = clients.lock().await;
if let Some((_, handle)) =
clients_guard.iter().find(|(_, h)| h.id() == client_id)
{
if let Err(e) = handle.response_sender.send(list).await {
eprintln!(
"Failed to send list rooms response to client {}: {}",
client_id, e
);
}
} else {
eprintln!("Client {} not found for list rooms response", client_id);
}
}
RoomEvents::ListRoomsResponse {
client_id: _,
rooms: _,
} => {}
RoomEvents::LeaveAllRooms { client_id } => {
let mut rooms_guard = rooms.lock().await;
let mut rooms_to_remove = Vec::new();
for (index, room) in rooms_guard.iter_mut().enumerate() {
if room.room_clients.contains_key(&client_id) {
room.room_clients.remove(&client_id);
if room.room_clients.is_empty() {
rooms_to_remove.push(index);
}
}
}
for index in rooms_to_remove.iter().rev() {
rooms_guard.remove(*index);
}
}
}
}
});
}
}
#[cfg(feature = "with-ripress")]
impl Wynd<WithRipress> {
    /// Consumes the server and returns a hyper request handler that upgrades
    /// WebSocket requests and serves them with this server's handlers.
    ///
    /// Requests lacking `Upgrade: websocket`, `Sec-WebSocket-Key` or
    /// `Sec-WebSocket-Version` get a 400 response with the body
    /// "Expected WebSocket upgrade". Each upgraded connection is set up on a
    /// spawned task:
    /// - it is registered in `clients`;
    /// - when it closes, it is removed from `clients` and from every room,
    ///   and rooms left empty are dropped;
    /// - it is handed to the `on_connection` handler.
    ///
    /// Errors while serving it are logged and forwarded to `on_error`, if set.
    ///
    /// NOTE(review): the connection's `addr` is this server's configured
    /// `addr`, not the peer address. Also, nothing in this file drains the room
    /// event channel in this mode (only `listen` spawns a processor), so room
    /// events from ripress connections appear to go unprocessed. Confirm.
    pub fn handler(
        self,
    ) -> impl Fn(
        hyper::Request<hyper::Body>,
    ) -> Pin<Box<dyn Future<Output = hyper::Result<hyper::Response<hyper::Body>>> + Send>>
           + Send
           + Sync
           + 'static {
        let wynd = Arc::new(self);
        move |mut req| {
            let wynd = Arc::clone(&wynd);
            Box::pin(async move {
                // Reject anything that is not a well-formed WebSocket upgrade request.
                let headers = req.headers();
                let is_websocket_upgrade = headers
                    .get("upgrade")
                    .and_then(|h| h.to_str().ok())
                    .map(|h| h.eq_ignore_ascii_case("websocket"))
                    .unwrap_or(false);
                let has_websocket_key = headers.contains_key("sec-websocket-key");
                let has_websocket_version = headers.contains_key("sec-websocket-version");
                if !is_websocket_upgrade || !has_websocket_key || !has_websocket_version {
                    let response = hyper::Response::builder()
                        .status(400)
                        .body(hyper::Body::from("Expected WebSocket upgrade"))
                        .expect("a 400 response with a static body is always valid");
                    return Ok(response);
                }

                match hyper_tungstenite::upgrade(&mut req, None) {
                    Ok((response, websocket_future)) => {
                        tokio::spawn(async move {
                            let ws_stream = match websocket_future.await {
                                Ok(ws_stream) => ws_stream,
                                Err(e) => {
                                    eprintln!("WebSocket handshake failed: {:?}", e);
                                    return;
                                }
                            };

                            let connection_id =
                                wynd.next_connection_id.fetch_add(1, Ordering::Relaxed);
                            let mut connection =
                                Connection::new(connection_id, ws_stream, wynd.addr);
                            connection.set_clients_registry(Arc::clone(&wynd.clients));
                            let broadcaster = Broadcaster {
                                clients: Arc::clone(&wynd.clients),
                                current_client_id: connection_id,
                            };
                            let (response_sender, response_receiver) =
                                tokio::sync::mpsc::channel(10);
                            let handle = Arc::new(ConnectionHandle {
                                id: connection.id(),
                                writer: Arc::clone(&connection.writer),
                                addr: wynd.addr,
                                broadcast: broadcaster,
                                state: Arc::clone(&connection.state),
                                room_sender: Arc::clone(&wynd.room_sender),
                                response_sender: Arc::new(response_sender),
                                response_receiver: Arc::new(Mutex::new(response_receiver)),
                            });
                            let arc_connection = Arc::new(connection);
                            arc_connection.set_handle(Arc::clone(&handle)).await;
                            wynd.clients
                                .lock()
                                .await
                                .push((Arc::clone(&arc_connection), Arc::clone(&handle)));

                            // On close, unregister from `clients` and from every room
                            // (dropping emptied rooms); the two locks are never held
                            // at the same time.
                            let clients_registry = Arc::clone(&wynd.clients);
                            let rooms_registry = Arc::clone(&wynd.rooms);
                            let handle_id = handle.id();
                            arc_connection.on_close(move |_event| {
                                let clients_registry = Arc::clone(&clients_registry);
                                let rooms_registry = Arc::clone(&rooms_registry);
                                async move {
                                    clients_registry
                                        .lock()
                                        .await
                                        .retain(|(_c, h)| h.id() != handle_id);
                                    rooms_registry.lock().await.retain_mut(|room| {
                                        room.room_clients.remove(&handle_id);
                                        !room.room_clients.is_empty()
                                    });
                                }
                            });

                            // Stringify the error immediately: `Box<dyn Error>` is not
                            // `Send` and must not stay alive across the `.await` below.
                            let outcome = wynd
                                .handle_websocket_connection(Arc::clone(&arc_connection))
                                .await
                                .map_err(|e| e.to_string());
                            if let Err(message) = outcome {
                                eprintln!("Error handling WebSocket connection: {}", message);
                                if let Some(error_handler) = &wynd.error_handler {
                                    error_handler(WyndError::new(message)).await;
                                }
                            }
                        });
                        Ok(response)
                    }
                    Err(e) => {
                        eprintln!("WebSocket upgrade failed: {:?}", e);
                        let response = hyper::Response::builder()
                            .status(400)
                            .body(hyper::Body::from("WebSocket upgrade failed"))
                            .expect("a 400 response with a static body is always valid");
                        Ok(response)
                    }
                }
            })
        }
    }

    /// Awaits `on_open` with a no-op callback, then runs the `on_connection`
    /// handler (if any) for an already-upgraded connection. Currently always
    /// returns `Ok(())`.
    async fn handle_websocket_connection(
        &self,
        connection: Arc<Connection<WithRipress>>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        connection.on_open(|_handle| async move {}).await;
        if let Some(handler) = &self.connection_handler {
            handler(connection).await;
        }
        Ok(())
    }
}