#[cfg(feature = "websocket")]
use futures_util::{SinkExt, StreamExt};
#[cfg(feature = "websocket")]
use std::collections::HashMap;
#[cfg(feature = "websocket")]
use std::net::SocketAddr;
#[cfg(feature = "websocket")]
use std::sync::Arc;
#[cfg(feature = "websocket")]
use tokio::net::{TcpListener, TcpStream};
#[cfg(feature = "websocket")]
use tokio::sync::{Mutex, RwLock, broadcast};
#[cfg(feature = "websocket")]
use tokio_tungstenite::{WebSocketStream, accept_async, tungstenite::Message};
#[cfg(feature = "websocket")]
use crate::shutdown::ShutdownCoordinator;
#[cfg(feature = "websocket")]
/// Write half of an accepted WebSocket connection, obtained by splitting the
/// stream with `StreamExt::split`.
type WsWriter = futures_util::stream::SplitSink<WebSocketStream<TcpStream>, Message>;
#[cfg(feature = "websocket")]
/// Registry entry for one connection tracked by [`BroadcastManager`].
struct Client {
// Duplicates the registry's map key and is not read anywhere in this file,
// so the `dead_code` lint is silenced explicitly.
#[allow(dead_code)]
addr: SocketAddr,
/// Write half of the peer's socket. It is locked for each send, and the guard
/// is held across the send's `.await`, hence tokio's async `Mutex`.
sender: Arc<Mutex<WsWriter>>,
}
#[cfg(feature = "websocket")]
#[derive(Clone)]
/// Registry of broadcast-mode connections plus the channel used to fan
/// messages out to them.
///
/// Clones share the same registry (it is behind an `Arc`) and the same
/// broadcast channel.
pub struct BroadcastManager {
/// Connected peers, keyed by their socket address.
clients: Arc<RwLock<HashMap<SocketAddr, Arc<Client>>>>,
/// Sending side of the channel; each connection subscribes on its own.
broadcast_tx: broadcast::Sender<String>,
}
#[cfg(feature = "websocket")]
impl BroadcastManager {
    /// Creates an empty manager backed by a broadcast channel of `capacity` slots.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`broadcast::channel`] (tokio
    /// documents a panic for `capacity == 0`).
    pub fn new(capacity: usize) -> Self {
        // The initial receiver is discarded; every connection subscribes itself.
        let (broadcast_tx, _) = broadcast::channel(capacity);
        Self {
            clients: Arc::new(RwLock::new(HashMap::new())),
            broadcast_tx,
        }
    }

    /// Stores `sender` as the write half for the connection at `addr`,
    /// replacing any entry previously stored under the same address.
    async fn register_client(&self, addr: SocketAddr, sender: WsWriter) {
        let client = Arc::new(Client {
            addr,
            sender: Arc::new(Mutex::new(sender)),
        });
        self.clients.write().await.insert(addr, client);
    }

    /// Removes the entry for `addr`, if there is one.
    async fn unregister_client(&self, addr: &SocketAddr) {
        self.clients.write().await.remove(addr);
    }

    /// Queues `message` for every currently subscribed connection.
    ///
    /// Delivery is best-effort: when nobody is subscribed the message is dropped.
    pub async fn broadcast(&self, message: String) {
        // `message` is already owned, so it is moved into the channel instead of
        // being cloned. The only send error is "no receivers", which is not a
        // failure for a fire-and-forget broadcast.
        let _ = self.broadcast_tx.send(message);
    }

    /// Returns the number of connections currently in the registry.
    pub async fn client_count(&self) -> usize {
        self.clients.read().await.len()
    }

    /// Creates a receiver that sees every message broadcast after this call.
    fn subscribe(&self) -> broadcast::Receiver<String> {
        self.broadcast_tx.subscribe()
    }
}
#[cfg(feature = "websocket")]
#[async_trait::async_trait]
/// Application callbacks invoked by [`WebSocketServer`] for each connection.
///
/// One handler is shared by all connection tasks behind an `Arc`, which is why
/// implementations must be `Send + Sync`.
pub trait WebSocketHandler: Send + Sync {
/// Handles one text message from a peer.
///
/// The returned string is sent back to that same peer as a text frame,
/// whether it is `Ok` or `Err`. Binary frames never reach this method.
async fn handle_message(&self, message: String) -> Result<String, String>;
/// Called after the WebSocket handshake completes (in broadcast mode, after
/// the peer has been registered). The default implementation does nothing.
async fn on_connect(&self) {}
/// Called by `WebSocketServer::handle_connection` once it stops serving the
/// connection. The default implementation does nothing.
/// NOTE(review): which exit paths reach this is decided by `handle_connection`;
/// check there before relying on it for resource cleanup.
async fn on_disconnect(&self) {}
}
#[cfg(feature = "websocket")]
/// WebSocket server that passes incoming text messages to a [`WebSocketHandler`].
pub struct WebSocketServer {
/// Shared by every connection task the server spawns.
handler: Arc<dyn WebSocketHandler>,
/// `None` unless broadcasting was enabled (for example via `with_broadcast`).
pub broadcast_manager: Option<BroadcastManager>,
}
#[cfg(feature = "websocket")]
impl WebSocketServer {
    /// Creates a server that dispatches text messages to `handler`.
    ///
    /// Broadcasting starts disabled; enable it with [`WebSocketServer::with_broadcast`].
    pub fn new<H: WebSocketHandler + 'static>(handler: H) -> Self {
        Self {
            handler: Arc::new(handler),
            broadcast_manager: None,
        }
    }

    /// Enables broadcasting with a channel that buffers `capacity` messages,
    /// replacing any manager that was already set.
    pub fn with_broadcast(mut self, capacity: usize) -> Self {
        self.broadcast_manager = Some(BroadcastManager::new(capacity));
        self
    }

    /// Creates a server from a handler that is already behind an `Arc`.
    ///
    /// Broadcasting starts disabled.
    pub fn from_arc(handler: Arc<dyn WebSocketHandler>) -> Self {
        Self {
            handler,
            broadcast_manager: None,
        }
    }

    /// Returns the broadcast manager, or `None` if broadcasting is disabled.
    pub fn broadcast_manager(&self) -> Option<&BroadcastManager> {
        self.broadcast_manager.as_ref()
    }

    /// Binds `addr` and serves each accepted connection on its own spawned task.
    ///
    /// Runs until an error occurs. Per-connection errors are logged to stderr,
    /// not returned.
    ///
    /// # Errors
    ///
    /// Returns an error if binding `addr` fails or if `accept` fails.
    pub async fn listen(self, addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
        let Self {
            handler,
            broadcast_manager,
        } = self;
        let listener = TcpListener::bind(addr).await?;
        println!("WebSocket server listening on ws://{}", addr);
        loop {
            let (stream, peer_addr) = listener.accept().await?;
            let conn_handler = Arc::clone(&handler);
            let manager = broadcast_manager.clone();
            tokio::spawn(async move {
                if let Err(e) =
                    Self::handle_connection(stream, conn_handler, peer_addr, manager).await
                {
                    eprintln!("Error handling WebSocket connection: {:?}", e);
                }
            });
        }
    }

    /// Like [`WebSocketServer::listen`], but stops accepting once `coordinator`
    /// signals shutdown, then calls `coordinator.notify_shutdown_complete()`.
    ///
    /// Every connection task also subscribes to the coordinator. When the signal
    /// arrives, the connection's message loop ends and its normal cleanup
    /// (unregistering from the broadcast manager, `on_disconnect`) still runs.
    /// The completion notification is sent as soon as the accept loop exits; it
    /// does not wait for in-flight connection tasks.
    ///
    /// # Errors
    ///
    /// Returns an error if binding `addr` fails or if `accept` fails.
    pub async fn listen_with_shutdown(
        self,
        addr: SocketAddr,
        coordinator: ShutdownCoordinator,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let Self {
            handler,
            broadcast_manager,
        } = self;
        let listener = TcpListener::bind(addr).await?;
        println!("WebSocket server listening on ws://{}", addr);
        let mut shutdown_rx = coordinator.subscribe();
        loop {
            tokio::select! {
                accepted = listener.accept() => {
                    let (stream, peer_addr) = accepted?;
                    let conn_handler = Arc::clone(&handler);
                    let manager = broadcast_manager.clone();
                    let mut conn_shutdown = coordinator.subscribe();
                    tokio::spawn(async move {
                        // Any completion of `recv` (signal or closed channel) ends the
                        // connection, mirroring how the accept loop treats `shutdown_rx`.
                        let shutdown = async move {
                            let _ = conn_shutdown.recv().await;
                        };
                        let served = Self::serve_connection(
                            stream,
                            conn_handler,
                            peer_addr,
                            manager,
                            shutdown,
                        )
                        .await;
                        if let Err(e) = served {
                            eprintln!("Error handling WebSocket connection: {:?}", e);
                        }
                    });
                }
                _ = shutdown_rx.recv() => {
                    println!("Shutdown signal received, stopping WebSocket server...");
                    break;
                }
            }
        }
        coordinator.notify_shutdown_complete();
        Ok(())
    }

    /// Performs the WebSocket handshake on `stream` and serves the connection
    /// until the peer closes it, the stream ends or errors, or a broadcast
    /// cannot be delivered.
    ///
    /// Text messages go to the handler, and its `Ok` or `Err` string is sent
    /// back as a text frame. Binary messages are echoed back unchanged. With a
    /// broadcast manager, the connection is registered under `peer_addr` and
    /// every broadcast message is forwarded to the peer.
    ///
    /// After a successful handshake, cleanup (unregistering from the manager and
    /// `on_disconnect`) runs on every exit path, including errors.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following happens:
    /// - the handshake fails;
    /// - `to_text` fails on a text frame;
    /// - sending a reply to a text or binary frame fails;
    /// - the freshly registered client cannot be found in the registry.
    pub async fn handle_connection(
        stream: TcpStream,
        handler: Arc<dyn WebSocketHandler>,
        peer_addr: SocketAddr,
        broadcast_manager: Option<BroadcastManager>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Self::serve_connection(
            stream,
            handler,
            peer_addr,
            broadcast_manager,
            std::future::pending::<()>(),
        )
        .await
    }

    /// Shared implementation of [`WebSocketServer::handle_connection`].
    ///
    /// It additionally stops serving, still running cleanup, when `shutdown`
    /// completes.
    async fn serve_connection<F>(
        stream: TcpStream,
        handler: Arc<dyn WebSocketHandler>,
        peer_addr: SocketAddr,
        broadcast_manager: Option<BroadcastManager>,
        shutdown: F,
    ) -> Result<(), Box<dyn std::error::Error>>
    where
        F: std::future::Future<Output = ()>,
    {
        #[cfg(debug_assertions)]
        eprintln!("[ws:debug] new connection from: {}", peer_addr);
        let mut shutdown = std::pin::pin!(shutdown);
        let ws_stream = tokio::select! {
            accepted = accept_async(stream) => accepted?,
            _ = shutdown.as_mut() => return Ok(()),
        };
        let (write, mut read) = ws_stream.split();

        // Resolve the write half once. In broadcast mode it lives in the manager's
        // registry; cloning the `Arc` out means later sends never hold the
        // registry's read lock across a network write. Holding it would block
        // `register_client`/`unregister_client`, and (tokio's `RwLock` being
        // write-preferring) every other connection's lookups as well.
        let sender = match &broadcast_manager {
            Some(manager) => {
                manager.register_client(peer_addr, write).await;
                let registered = manager
                    .clients
                    .read()
                    .await
                    .get(&peer_addr)
                    .map(|client| Arc::clone(&client.sender));
                registered.ok_or("client entry vanished right after registration")?
            }
            None => Arc::new(Mutex::new(write)),
        };

        handler.on_connect().await;
        let broadcast_rx = broadcast_manager.as_ref().map(BroadcastManager::subscribe);
        let outcome = Self::pump_messages(
            &*handler,
            &mut read,
            &sender,
            broadcast_rx,
            shutdown.as_mut(),
        )
        .await;

        // Release per-connection state on every exit path, including errors, so
        // the registry never keeps entries for dead connections.
        if let Some(manager) = &broadcast_manager {
            manager.unregister_client(&peer_addr).await;
        }
        handler.on_disconnect().await;
        #[cfg(debug_assertions)]
        eprintln!("[ws:debug] connection closed for peer");
        outcome.map_err(Into::into)
    }

    /// Message loop for an established connection.
    ///
    /// Ends when the peer sends a close frame, the stream ends or errors, a
    /// broadcast cannot be delivered, or `shutdown` completes; all of these
    /// yield `Ok(())`. Returns `Err` only when `to_text` fails or sending a
    /// reply to a text/binary frame fails.
    async fn pump_messages<F>(
        handler: &dyn WebSocketHandler,
        read: &mut futures_util::stream::SplitStream<WebSocketStream<TcpStream>>,
        sender: &Mutex<WsWriter>,
        mut broadcast_rx: Option<broadcast::Receiver<String>>,
        mut shutdown: std::pin::Pin<&mut F>,
    ) -> Result<(), tokio_tungstenite::tungstenite::Error>
    where
        F: std::future::Future<Output = ()>,
    {
        loop {
            tokio::select! {
                incoming = read.next() => {
                    let msg = match incoming {
                        Some(Ok(msg)) => msg,
                        Some(Err(e)) => {
                            eprintln!("WebSocket error: {}", e);
                            return Ok(());
                        }
                        None => return Ok(()),
                    };
                    if msg.is_text() {
                        let text = msg.to_text()?;
                        #[cfg(debug_assertions)]
                        eprintln!("[ws:trace] text message from peer ({} bytes)", text.len());
                        // Handler errors are reported to the peer exactly like replies.
                        let reply = match handler.handle_message(text.to_string()).await {
                            Ok(reply) | Err(reply) => reply,
                        };
                        sender.lock().await.send(Message::Text(reply.into())).await?;
                    } else if msg.is_binary() {
                        let data = msg.into_data();
                        #[cfg(debug_assertions)]
                        eprintln!("[ws:trace] binary message from peer ({} bytes)", data.len());
                        sender.lock().await.send(Message::Binary(data)).await?;
                    } else if msg.is_close() {
                        #[cfg(debug_assertions)]
                        eprintln!("[ws:debug] connection closing for peer");
                        {
                            // Best-effort close reply; the guard is released before sleeping.
                            let mut sink = sender.lock().await;
                            let _ = sink.send(Message::Close(None)).await;
                            let _ = sink.flush().await;
                        }
                        // Presumably gives the close frame time to reach the peer
                        // before the socket is dropped.
                        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
                        return Ok(());
                    }
                    // Other frame kinds (ping/pong/raw frames) are ignored here.
                }
                broadcast_msg = async {
                    match broadcast_rx.as_mut() {
                        Some(rx) => rx.recv().await.ok(),
                        // Without a broadcast manager this branch must never fire.
                        None => std::future::pending().await,
                    }
                } => {
                    // A `recv` error (e.g. the receiver lagged) yields `None`; skip it.
                    if let Some(msg) = broadcast_msg {
                        let delivered = sender.lock().await.send(Message::Text(msg.into())).await;
                        if let Err(e) = delivered {
                            eprintln!("Failed to send broadcast message: {}", e);
                            return Ok(());
                        }
                    }
                }
                _ = shutdown.as_mut() => return Ok(()),
            }
        }
    }
}
#[cfg(feature = "websocket")]
/// Convenience entry point that serves `handler` on `addr` with broadcasting disabled.
///
/// Shorthand for `WebSocketServer::new(handler).listen(addr)`; see
/// [`WebSocketServer::listen`] for the error conditions.
pub async fn serve_websocket<H: WebSocketHandler + 'static>(
    addr: SocketAddr,
    handler: H,
) -> Result<(), Box<dyn std::error::Error>> {
    WebSocketServer::new(handler).listen(addr).await
}
#[cfg(all(test, feature = "websocket"))]
mod tests {
    use super::*;

    /// Minimal handler that prefixes every message with `"Echo: "`.
    struct EchoHandler;

    #[async_trait::async_trait]
    impl WebSocketHandler for EchoHandler {
        async fn handle_message(&self, message: String) -> Result<String, String> {
            Ok(format!("Echo: {}", message))
        }
    }

    #[tokio::test]
    async fn test_websocket_server_creation() {
        let server = WebSocketServer::new(EchoHandler);
        assert!(server.broadcast_manager().is_none());
    }

    #[tokio::test]
    async fn test_websocket_server_with_broadcast() {
        let server = WebSocketServer::new(EchoHandler).with_broadcast(100);
        let manager = server
            .broadcast_manager()
            .expect("with_broadcast must install a broadcast manager");
        assert_eq!(manager.client_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_server_from_arc_starts_without_broadcast() {
        let handler: Arc<dyn WebSocketHandler> = Arc::new(EchoHandler);
        let server = WebSocketServer::from_arc(handler);
        assert!(server.broadcast_manager().is_none());
    }

    #[tokio::test]
    async fn test_echo_handler_replies() {
        let reply = EchoHandler.handle_message("hi".to_string()).await;
        assert_eq!(reply, Ok("Echo: hi".to_string()));
    }

    #[tokio::test]
    async fn test_broadcast_manager_creation() {
        let manager = BroadcastManager::new(50);
        assert_eq!(manager.client_count().await, 0);
    }

    #[tokio::test]
    async fn test_broadcast_manager_broadcast() {
        let manager = BroadcastManager::new(50);
        let mut rx = manager.subscribe();
        manager.broadcast("Hello!".to_string()).await;
        let received = rx.recv().await.unwrap();
        assert_eq!(received, "Hello!");
    }

    #[tokio::test]
    async fn test_broadcast_without_subscribers_is_noop() {
        let manager = BroadcastManager::new(4);
        // Must neither panic nor register anything when nobody is listening.
        manager.broadcast("nobody listening".to_string()).await;
        assert_eq!(manager.client_count().await, 0);
    }

    #[tokio::test]
    async fn test_broadcast_manager_multiple_subscribers() {
        let manager = BroadcastManager::new(50);
        let mut rx1 = manager.subscribe();
        let mut rx2 = manager.subscribe();
        let mut rx3 = manager.subscribe();
        manager.broadcast("Broadcast message".to_string()).await;
        assert_eq!(rx1.recv().await.unwrap(), "Broadcast message");
        assert_eq!(rx2.recv().await.unwrap(), "Broadcast message");
        assert_eq!(rx3.recv().await.unwrap(), "Broadcast message");
    }
}