use crate::connection::{Connection, ConnectionWriter};
use crate::error::{WebSocketError, WebSocketResult};
use crate::handler::WebSocketHandler;
use crate::message::Message;
use crate::room::RoomManager;
use futures_util::StreamExt;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use tokio_tungstenite::accept_async;
/// Runtime settings for a WebSocket server.
///
/// Values are plain data; cloning produces an independent copy.
#[derive(Debug, Clone)]
pub struct WebSocketServerConfig {
    /// Socket address the server listens on.
    pub bind_addr: SocketAddr,
    /// Upper bound on the size of a single message (presumably in bytes — confirm
    /// against the code that enforces it).
    pub max_message_size: usize,
    /// Interval between heartbeats.
    // NOTE(review): how heartbeats are driven is not visible in this struct.
    pub heartbeat_interval: Duration,
    /// Time limit associated with a connection.
    // NOTE(review): exact semantics (idle vs. handshake) are decided by the consumer.
    pub connection_timeout: Duration,
}
impl Default for WebSocketServerConfig {
fn default() -> Self {
Self {
bind_addr: "0.0.0.0:9001".parse().unwrap(),
max_message_size: 64 * 1024, heartbeat_interval: Duration::from_secs(30),
connection_timeout: Duration::from_secs(60),
}
}
}
/// Fluent builder that accumulates a [`WebSocketServerConfig`].
///
/// `Default` yields a builder holding the default configuration.
#[derive(Debug, Default)]
pub struct WebSocketServerBuilder {
    /// Configuration being assembled.
    config: WebSocketServerConfig,
}
impl WebSocketServerBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
self.config.bind_addr = addr;
self
}
pub fn bind(mut self, addr: &str) -> WebSocketResult<Self> {
self.config.bind_addr = addr
.parse()
.map_err(|e| WebSocketError::Server(format!("Invalid address: {}", e)))?;
Ok(self)
}
pub fn max_message_size(mut self, size: usize) -> Self {
self.config.max_message_size = size;
self
}
pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
self.config.heartbeat_interval = interval;
self
}
pub fn connection_timeout(mut self, timeout: Duration) -> Self {
self.config.connection_timeout = timeout;
self
}
pub fn build<H: WebSocketHandler>(self, handler: H) -> WebSocketServer<H> {
WebSocketServer::new(self.config, handler)
}
}
/// A WebSocket server parameterised over its event handler `H`.
///
/// The handler and the room manager are reference-counted so they can be
/// handed to other owners without copying.
pub struct WebSocketServer<H: WebSocketHandler> {
    /// Settings this server was created with.
    config: WebSocketServerConfig,
    /// Shared handler instance.
    handler: Arc<H>,
    /// Shared room manager instance.
    room_manager: Arc<RoomManager>,
}
impl<H: WebSocketHandler> WebSocketServer<H> {
pub fn new(config: WebSocketServerConfig, handler: H) -> Self {
Self {
config,
handler: Arc::new(handler),
room_manager: Arc::new(RoomManager::new()),
}
}
pub fn builder() -> WebSocketServerBuilder {
WebSocketServerBuilder::new()
}
pub fn room_manager(&self) -> &Arc<RoomManager> {
&self.room_manager
}
pub async fn run(&self) -> WebSocketResult<()> {
let listener = TcpListener::bind(self.config.bind_addr).await?;
tracing::info!(addr = %self.config.bind_addr, "WebSocket server listening");
loop {
match listener.accept().await {
Ok((stream, addr)) => {
let handler = Arc::clone(&self.handler);
let room_manager = Arc::clone(&self.room_manager);
let config = self.config.clone();
tokio::spawn(async move {
if let Err(e) =
Self::handle_connection(stream, addr, handler, room_manager, config)
.await
{
tracing::error!(addr = %addr, error = %e, "Connection error");
}
});
}
Err(e) => {
tracing::error!(error = %e, "Failed to accept connection");
}
}
}
}
async fn handle_connection(
stream: TcpStream,
addr: SocketAddr,
handler: Arc<H>,
room_manager: Arc<RoomManager>,
_config: WebSocketServerConfig,
) -> WebSocketResult<()> {
let ws_stream = accept_async(stream).await?;
let connection_id = uuid::Uuid::new_v4().to_string();
tracing::debug!(connection_id = %connection_id, addr = %addr, "WebSocket connection established");
let (write, mut read) = ws_stream.split();
let (tx, rx) = mpsc::unbounded_channel();
let connection = Connection::new(connection_id.clone(), Some(addr), tx);
room_manager.register_connection(connection.clone());
handler.on_connect(&connection_id).await;
let writer = ConnectionWriter::new(write, rx);
let writer_handle = tokio::spawn(async move { writer.run().await });
while let Some(result) = read.next().await {
match result {
Ok(msg) => {
if msg.is_close() {
break;
}
let message: Message = msg.into();
if message.is_ping() {
let pong_payload = handler.on_ping(&connection_id, message.as_bytes()).await;
let _ = connection.send(Message::pong(pong_payload));
continue;
}
if message.is_pong() {
handler.on_pong(&connection_id, message.as_bytes()).await;
continue;
}
handler.on_message(&connection_id, message).await;
}
Err(e) => {
let ws_error = WebSocketError::Protocol(e);
handler.on_error(&connection_id, &ws_error).await;
break;
}
}
}
connection.close();
let _ = writer_handle.await;
handler.on_disconnect(&connection_id).await;
room_manager.unregister_connection(&connection_id);
tracing::debug!(connection_id = %connection_id, "WebSocket connection closed");
Ok(())
}
}