use super::connection::WebSocketConnection;
use super::registry::{ConnectionRegistry, RegistryStats};
use super::types::{ConnectionId, WebSocketConfig, WebSocketMessage, WebSocketResult};
use crate::routing::ElifRouter;
use axum::{extract::ws::WebSocketUpgrade as AxumWebSocketUpgrade, routing::get};
use std::sync::Arc;
use tokio::time::{interval, Duration};
use tracing::{debug, info};
/// WebSocket server facade that owns a shared connection registry and an
/// optional background cleanup task.
pub struct WebSocketServer {
/// Shared connection registry; the broadcast/send/close methods on this type
/// forward to it.
registry: Arc<ConnectionRegistry>,
/// Configuration supplied at construction time.
/// NOTE(review): stored but not read by any code visible in this file.
_config: WebSocketConfig,
/// Join handle of the periodic cleanup task; `None` while no task is running.
cleanup_handle: Option<tokio::task::JoinHandle<()>>,
}
impl WebSocketServer {
    /// Creates a server with an empty connection registry and the default
    /// [`WebSocketConfig`]. No cleanup task is started.
    pub fn new() -> Self {
        Self::with_config(WebSocketConfig::default())
    }

    /// Creates a server with an empty connection registry and the given
    /// configuration. No cleanup task is started.
    pub fn with_config(config: WebSocketConfig) -> Self {
        Self {
            registry: Arc::new(ConnectionRegistry::new()),
            _config: config,
            cleanup_handle: None,
        }
    }

    /// Returns a new shared handle to the connection registry.
    pub fn registry(&self) -> Arc<ConnectionRegistry> {
        Arc::clone(&self.registry)
    }

    /// Returns the statistics reported by the connection registry.
    pub async fn stats(&self) -> RegistryStats {
        self.registry.stats().await
    }

    /// Registers a WebSocket upgrade route at `path` on `router` and returns
    /// the updated router.
    ///
    /// NOTE(review): `_handler` is accepted for API compatibility but is not
    /// invoked yet. The installed route is a placeholder that answers every
    /// received message with the text frame `"pong"` and does not register the
    /// socket with the connection registry.
    pub fn add_websocket_route<F, Fut>(
        &self,
        router: ElifRouter,
        path: &str,
        _handler: F,
    ) -> ElifRouter
    where
        F: Fn(ConnectionId, Arc<WebSocketConnection>) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = ()> + Send + 'static,
    {
        let ws_handler = move |ws: AxumWebSocketUpgrade| async move {
            ws.on_upgrade(|mut socket| async move {
                info!("WebSocket connection established");
                // Stop on end-of-stream *and* on a receive error: after a failed
                // read there is nothing meaningful to reply to.
                while let Some(Ok(_msg)) = socket.recv().await {
                    let reply = axum::extract::ws::Message::Text("pong".to_string());
                    if socket.send(reply).await.is_err() {
                        // The peer is gone or the socket is broken.
                        break;
                    }
                }
                info!("WebSocket connection closed");
            })
        };
        router.add_axum_route(path, get(ws_handler))
    }

    /// Alias for [`WebSocketServer::add_websocket_route`]; forwards all
    /// arguments unchanged.
    pub fn add_handler<F, Fut>(&self, router: ElifRouter, path: &str, handler: F) -> ElifRouter
    where
        F: Fn(ConnectionId, Arc<WebSocketConnection>) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = ()> + Send + 'static,
    {
        self.add_websocket_route(router, path, handler)
    }

    /// Forwards `message` to the registry's `broadcast` and returns its result.
    pub async fn broadcast(&self, message: WebSocketMessage) -> super::registry::BroadcastResult {
        self.registry.broadcast(message).await
    }

    /// Forwards `text` to the registry's `broadcast_text` and returns its result.
    pub async fn broadcast_text<T: Into<String>>(
        &self,
        text: T,
    ) -> super::registry::BroadcastResult {
        self.registry.broadcast_text(text).await
    }

    /// Forwards `data` to the registry's `broadcast_binary` and returns its result.
    pub async fn broadcast_binary<T: Into<Vec<u8>>>(
        &self,
        data: T,
    ) -> super::registry::BroadcastResult {
        self.registry.broadcast_binary(data).await
    }

    /// Asks the registry to send `message` to the connection identified by `id`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the registry's `send_to_connection` reports.
    pub async fn send_to_connection(
        &self,
        id: ConnectionId,
        message: WebSocketMessage,
    ) -> WebSocketResult<()> {
        self.registry.send_to_connection(id, message).await
    }

    /// Asks the registry to send `text` to the connection identified by `id`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the registry's `send_text_to_connection` reports.
    pub async fn send_text_to_connection<T: Into<String>>(
        &self,
        id: ConnectionId,
        text: T,
    ) -> WebSocketResult<()> {
        self.registry.send_text_to_connection(id, text).await
    }

    /// Asks the registry to send binary `data` to the connection identified by `id`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the registry's `send_binary_to_connection` reports.
    pub async fn send_binary_to_connection<T: Into<Vec<u8>>>(
        &self,
        id: ConnectionId,
        data: T,
    ) -> WebSocketResult<()> {
        self.registry.send_binary_to_connection(id, data).await
    }

    /// Returns the connection ids currently reported by the registry.
    pub async fn get_connection_ids(&self) -> Vec<ConnectionId> {
        self.registry.get_connection_ids().await
    }

    /// Returns the connection count currently reported by the registry.
    pub async fn connection_count(&self) -> usize {
        self.registry.connection_count().await
    }

    /// Asks the registry to close the connection identified by `id`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the registry's `close_connection` reports.
    pub async fn close_connection(&self, id: ConnectionId) -> WebSocketResult<()> {
        self.registry.close_connection(id).await
    }

    /// Asks the registry to close every connection and returns its result.
    pub async fn close_all_connections(&self) -> super::registry::CloseAllResult {
        self.registry.close_all_connections().await
    }

    /// Spawns a background task that calls the registry's
    /// `cleanup_inactive_connections` once per `interval_seconds`.
    ///
    /// Does nothing (beyond a debug log) if a cleanup task handle is already
    /// stored. An `interval_seconds` of `0` is raised to `1`, because
    /// `tokio::time::interval` panics on a zero period and that panic would
    /// otherwise kill the spawned task silently while this server still
    /// believed cleanup was running.
    ///
    /// # Panics
    ///
    /// Panics if called outside of a Tokio runtime, since it uses `tokio::spawn`.
    pub fn start_cleanup_task(&mut self, interval_seconds: u64) {
        if self.cleanup_handle.is_some() {
            debug!("Cleanup task already running");
            return;
        }

        let interval_seconds = if interval_seconds == 0 {
            tracing::warn!("WebSocket cleanup interval of 0s is invalid; using 1s instead");
            1
        } else {
            interval_seconds
        };

        let registry = Arc::clone(&self.registry);
        let handle = tokio::spawn(async move {
            let mut cleanup_interval = interval(Duration::from_secs(interval_seconds));
            loop {
                // The first tick of a Tokio interval completes immediately, so
                // one cleanup pass runs right after the task starts.
                cleanup_interval.tick().await;
                let cleaned = registry.cleanup_inactive_connections().await;
                if cleaned > 0 {
                    debug!("Cleanup task removed {} inactive connections", cleaned);
                }
            }
        });

        self.cleanup_handle = Some(handle);
        info!(
            "Started WebSocket cleanup task with {}s interval",
            interval_seconds
        );
    }

    /// Aborts the background cleanup task if one is running and clears the
    /// stored handle; otherwise does nothing.
    pub fn stop_cleanup_task(&mut self) {
        if let Some(handle) = self.cleanup_handle.take() {
            handle.abort();
            info!("Stopped WebSocket cleanup task");
        }
    }
}
impl Default for WebSocketServer {
fn default() -> Self {
Self::new()
}
}
impl Drop for WebSocketServer {
fn drop(&mut self) {
self.stop_cleanup_task();
}
}
/// Builder for [`WebSocketServer`]: collects configuration values and the
/// cleanup interval before constructing the server.
#[derive(Debug)]
pub struct WebSocketServerBuilder {
/// Configuration that is handed to the server when it is built.
_config: WebSocketConfig,
/// Interval in seconds passed to the server's cleanup task; `None` means no
/// cleanup task is started by `build`.
cleanup_interval: Option<u64>,
}
impl WebSocketServerBuilder {
pub fn new() -> Self {
Self {
_config: WebSocketConfig::default(),
cleanup_interval: Some(300), }
}
pub fn max_message_size(mut self, size: usize) -> Self {
self._config.max_message_size = Some(size);
self
}
pub fn max_frame_size(mut self, size: usize) -> Self {
self._config.max_frame_size = Some(size);
self
}
pub fn auto_pong(mut self, enabled: bool) -> Self {
self._config.auto_pong = enabled;
self
}
pub fn ping_interval(mut self, seconds: u64) -> Self {
self._config.ping_interval = Some(seconds);
self
}
pub fn connect_timeout(mut self, seconds: u64) -> Self {
self._config.connect_timeout = Some(seconds);
self
}
pub fn cleanup_interval(mut self, seconds: u64) -> Self {
self.cleanup_interval = Some(seconds);
self
}
pub fn no_cleanup(mut self) -> Self {
self.cleanup_interval = None;
self
}
pub fn build(self) -> WebSocketServer {
let mut server = WebSocketServer::with_config(self._config);
if let Some(interval) = self.cleanup_interval {
server.start_cleanup_task(interval);
}
server
}
}
impl Default for WebSocketServerBuilder {
fn default() -> Self {
Self::new()
}
}