// elif_http/websocket/server.rs
use std::sync::Arc;

use axum::{
    extract::ws::{Message, WebSocketUpgrade as AxumWebSocketUpgrade},
    routing::get,
};
use tokio::time::{interval, Duration};
use tracing::{debug, info, warn};

use crate::routing::ElifRouter;

use super::connection::WebSocketConnection;
use super::registry::{ConnectionRegistry, RegistryStats};
use super::types::{ConnectionId, WebSocketConfig, WebSocketMessage, WebSocketResult};

12pub struct WebSocketServer {
14 registry: Arc<ConnectionRegistry>,
16 _config: WebSocketConfig,
18 cleanup_handle: Option<tokio::task::JoinHandle<()>>,
20}
21
22impl WebSocketServer {
23 pub fn new() -> Self {
25 Self {
26 registry: Arc::new(ConnectionRegistry::new()),
27 _config: WebSocketConfig::default(),
28 cleanup_handle: None,
29 }
30 }
31
32 pub fn with_config(config: WebSocketConfig) -> Self {
34 Self {
35 registry: Arc::new(ConnectionRegistry::new()),
36 _config: config,
37 cleanup_handle: None,
38 }
39 }
40
41 pub fn registry(&self) -> Arc<ConnectionRegistry> {
43 self.registry.clone()
44 }
45
46 pub async fn stats(&self) -> RegistryStats {
48 self.registry.stats().await
49 }
50
51 pub fn add_websocket_route<F, Fut>(
54 &self,
55 router: ElifRouter,
56 path: &str,
57 _handler: F,
58 ) -> ElifRouter
59 where
60 F: Fn(ConnectionId, Arc<WebSocketConnection>) -> Fut + Send + Sync + Clone + 'static,
61 Fut: std::future::Future<Output = ()> + Send + 'static,
62 {
63 let ws_handler = move |ws: AxumWebSocketUpgrade| async move {
68 ws.on_upgrade(|mut socket| async move {
69 tracing::info!("WebSocket connection established");
70 while let Some(_msg) = socket.recv().await {
72 if let Ok(_) = socket
74 .send(axum::extract::ws::Message::Text("pong".to_string()))
75 .await
76 {
77 continue;
78 }
79 break;
80 }
81 tracing::info!("WebSocket connection closed");
82 })
83 };
84
85 router.add_axum_route(path, get(ws_handler))
87 }
88
89 pub fn add_handler<F, Fut>(&self, router: ElifRouter, path: &str, handler: F) -> ElifRouter
91 where
92 F: Fn(ConnectionId, Arc<WebSocketConnection>) -> Fut + Send + Sync + Clone + 'static,
93 Fut: std::future::Future<Output = ()> + Send + 'static,
94 {
95 self.add_websocket_route(router, path, handler)
96 }
97
98 pub async fn broadcast(&self, message: WebSocketMessage) -> super::registry::BroadcastResult {
100 self.registry.broadcast(message).await
101 }
102
103 pub async fn broadcast_text<T: Into<String>>(
105 &self,
106 text: T,
107 ) -> super::registry::BroadcastResult {
108 self.registry.broadcast_text(text).await
109 }
110
111 pub async fn broadcast_binary<T: Into<Vec<u8>>>(
113 &self,
114 data: T,
115 ) -> super::registry::BroadcastResult {
116 self.registry.broadcast_binary(data).await
117 }
118
119 pub async fn send_to_connection(
121 &self,
122 id: ConnectionId,
123 message: WebSocketMessage,
124 ) -> WebSocketResult<()> {
125 self.registry.send_to_connection(id, message).await
126 }
127
128 pub async fn send_text_to_connection<T: Into<String>>(
130 &self,
131 id: ConnectionId,
132 text: T,
133 ) -> WebSocketResult<()> {
134 self.registry.send_text_to_connection(id, text).await
135 }
136
137 pub async fn send_binary_to_connection<T: Into<Vec<u8>>>(
139 &self,
140 id: ConnectionId,
141 data: T,
142 ) -> WebSocketResult<()> {
143 self.registry.send_binary_to_connection(id, data).await
144 }
145
146 pub async fn get_connection_ids(&self) -> Vec<ConnectionId> {
148 self.registry.get_connection_ids().await
149 }
150
151 pub async fn connection_count(&self) -> usize {
153 self.registry.connection_count().await
154 }
155
156 pub async fn close_connection(&self, id: ConnectionId) -> WebSocketResult<()> {
158 self.registry.close_connection(id).await
159 }
160
161 pub async fn close_all_connections(&self) -> super::registry::CloseAllResult {
163 self.registry.close_all_connections().await
164 }
165
166 pub fn start_cleanup_task(&mut self, interval_seconds: u64) {
168 if self.cleanup_handle.is_some() {
169 debug!("Cleanup task already running");
170 return;
171 }
172
173 let registry = self.registry.clone();
174 let handle = tokio::spawn(async move {
175 let mut cleanup_interval = interval(Duration::from_secs(interval_seconds));
176
177 loop {
178 cleanup_interval.tick().await;
179 let cleaned = registry.cleanup_inactive_connections().await;
180 if cleaned > 0 {
181 debug!("Cleanup task removed {} inactive connections", cleaned);
182 }
183 }
184 });
185
186 self.cleanup_handle = Some(handle);
187 info!(
188 "Started WebSocket cleanup task with {}s interval",
189 interval_seconds
190 );
191 }
192
193 pub fn stop_cleanup_task(&mut self) {
195 if let Some(handle) = self.cleanup_handle.take() {
196 handle.abort();
197 info!("Stopped WebSocket cleanup task");
198 }
199 }
200}
201
202impl Default for WebSocketServer {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208impl Drop for WebSocketServer {
209 fn drop(&mut self) {
210 self.stop_cleanup_task();
211 }
212}
213
214#[derive(Debug)]
216pub struct WebSocketServerBuilder {
217 _config: WebSocketConfig,
218 cleanup_interval: Option<u64>,
219}
220
221impl WebSocketServerBuilder {
222 pub fn new() -> Self {
224 Self {
225 _config: WebSocketConfig::default(),
226 cleanup_interval: Some(300), }
228 }
229
230 pub fn max_message_size(mut self, size: usize) -> Self {
232 self._config.max_message_size = Some(size);
233 self
234 }
235
236 pub fn max_frame_size(mut self, size: usize) -> Self {
238 self._config.max_frame_size = Some(size);
239 self
240 }
241
242 pub fn auto_pong(mut self, enabled: bool) -> Self {
244 self._config.auto_pong = enabled;
245 self
246 }
247
248 pub fn ping_interval(mut self, seconds: u64) -> Self {
250 self._config.ping_interval = Some(seconds);
251 self
252 }
253
254 pub fn connect_timeout(mut self, seconds: u64) -> Self {
256 self._config.connect_timeout = Some(seconds);
257 self
258 }
259
260 pub fn cleanup_interval(mut self, seconds: u64) -> Self {
262 self.cleanup_interval = Some(seconds);
263 self
264 }
265
266 pub fn no_cleanup(mut self) -> Self {
268 self.cleanup_interval = None;
269 self
270 }
271
272 pub fn build(self) -> WebSocketServer {
274 let mut server = WebSocketServer::with_config(self._config);
275
276 if let Some(interval) = self.cleanup_interval {
277 server.start_cleanup_task(interval);
278 }
279
280 server
281 }
282}
283
284impl Default for WebSocketServerBuilder {
285 fn default() -> Self {
286 Self::new()
287 }
288}