elif_http/websocket/
server.rs1use super::registry::{ConnectionRegistry, RegistryStats};
4use super::types::{ConnectionId, WebSocketConfig, WebSocketMessage, WebSocketResult};
5use super::connection::WebSocketConnection;
6use crate::routing::ElifRouter;
7use axum::{
8 extract::ws::WebSocketUpgrade as AxumWebSocketUpgrade,
9 routing::get,
10};
11use std::sync::Arc;
12use tokio::time::{interval, Duration};
13use tracing::{debug, info};
14
15pub struct WebSocketServer {
17 registry: Arc<ConnectionRegistry>,
19 _config: WebSocketConfig,
21 cleanup_handle: Option<tokio::task::JoinHandle<()>>,
23}
24
25impl WebSocketServer {
26 pub fn new() -> Self {
28 Self {
29 registry: Arc::new(ConnectionRegistry::new()),
30 _config: WebSocketConfig::default(),
31 cleanup_handle: None,
32 }
33 }
34
35 pub fn with_config(config: WebSocketConfig) -> Self {
37 Self {
38 registry: Arc::new(ConnectionRegistry::new()),
39 _config: config,
40 cleanup_handle: None,
41 }
42 }
43
44 pub fn registry(&self) -> Arc<ConnectionRegistry> {
46 self.registry.clone()
47 }
48
49 pub async fn stats(&self) -> RegistryStats {
51 self.registry.stats().await
52 }
53
54 pub fn add_websocket_route<F, Fut>(
57 &self,
58 router: ElifRouter,
59 path: &str,
60 _handler: F,
61 ) -> ElifRouter
62 where
63 F: Fn(ConnectionId, Arc<WebSocketConnection>) -> Fut + Send + Sync + Clone + 'static,
64 Fut: std::future::Future<Output = ()> + Send + 'static,
65 {
66 let ws_handler = move |ws: AxumWebSocketUpgrade| async move {
71 ws.on_upgrade(|mut socket| async move {
72 tracing::info!("WebSocket connection established");
73 while let Some(_msg) = socket.recv().await {
75 if let Ok(_) = socket.send(axum::extract::ws::Message::Text("pong".to_string())).await {
77 continue;
78 }
79 break;
80 }
81 tracing::info!("WebSocket connection closed");
82 })
83 };
84
85 router.add_axum_route(path, get(ws_handler))
87 }
88
89 pub fn add_handler<F, Fut>(
91 &self,
92 router: ElifRouter,
93 path: &str,
94 handler: F,
95 ) -> ElifRouter
96 where
97 F: Fn(ConnectionId, Arc<WebSocketConnection>) -> Fut + Send + Sync + Clone + 'static,
98 Fut: std::future::Future<Output = ()> + Send + 'static,
99 {
100 self.add_websocket_route(router, path, handler)
101 }
102
103 pub async fn broadcast(&self, message: WebSocketMessage) -> super::registry::BroadcastResult {
105 self.registry.broadcast(message).await
106 }
107
108 pub async fn broadcast_text<T: Into<String>>(&self, text: T) -> super::registry::BroadcastResult {
110 self.registry.broadcast_text(text).await
111 }
112
113 pub async fn broadcast_binary<T: Into<Vec<u8>>>(&self, data: T) -> super::registry::BroadcastResult {
115 self.registry.broadcast_binary(data).await
116 }
117
118 pub async fn send_to_connection(
120 &self,
121 id: ConnectionId,
122 message: WebSocketMessage,
123 ) -> WebSocketResult<()> {
124 self.registry.send_to_connection(id, message).await
125 }
126
127 pub async fn send_text_to_connection<T: Into<String>>(
129 &self,
130 id: ConnectionId,
131 text: T,
132 ) -> WebSocketResult<()> {
133 self.registry.send_text_to_connection(id, text).await
134 }
135
136 pub async fn send_binary_to_connection<T: Into<Vec<u8>>>(
138 &self,
139 id: ConnectionId,
140 data: T,
141 ) -> WebSocketResult<()> {
142 self.registry.send_binary_to_connection(id, data).await
143 }
144
145 pub async fn get_connection_ids(&self) -> Vec<ConnectionId> {
147 self.registry.get_connection_ids().await
148 }
149
150 pub async fn connection_count(&self) -> usize {
152 self.registry.connection_count().await
153 }
154
155 pub async fn close_connection(&self, id: ConnectionId) -> WebSocketResult<()> {
157 self.registry.close_connection(id).await
158 }
159
160 pub async fn close_all_connections(&self) -> super::registry::CloseAllResult {
162 self.registry.close_all_connections().await
163 }
164
165 pub fn start_cleanup_task(&mut self, interval_seconds: u64) {
167 if self.cleanup_handle.is_some() {
168 debug!("Cleanup task already running");
169 return;
170 }
171
172 let registry = self.registry.clone();
173 let handle = tokio::spawn(async move {
174 let mut cleanup_interval = interval(Duration::from_secs(interval_seconds));
175
176 loop {
177 cleanup_interval.tick().await;
178 let cleaned = registry.cleanup_inactive_connections().await;
179 if cleaned > 0 {
180 debug!("Cleanup task removed {} inactive connections", cleaned);
181 }
182 }
183 });
184
185 self.cleanup_handle = Some(handle);
186 info!("Started WebSocket cleanup task with {}s interval", interval_seconds);
187 }
188
189 pub fn stop_cleanup_task(&mut self) {
191 if let Some(handle) = self.cleanup_handle.take() {
192 handle.abort();
193 info!("Stopped WebSocket cleanup task");
194 }
195 }
196}
197
198impl Default for WebSocketServer {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
204impl Drop for WebSocketServer {
205 fn drop(&mut self) {
206 self.stop_cleanup_task();
207 }
208}
209
210#[derive(Debug)]
212pub struct WebSocketServerBuilder {
213 _config: WebSocketConfig,
214 cleanup_interval: Option<u64>,
215}
216
217impl WebSocketServerBuilder {
218 pub fn new() -> Self {
220 Self {
221 _config: WebSocketConfig::default(),
222 cleanup_interval: Some(300), }
224 }
225
226 pub fn max_message_size(mut self, size: usize) -> Self {
228 self._config.max_message_size = Some(size);
229 self
230 }
231
232 pub fn max_frame_size(mut self, size: usize) -> Self {
234 self._config.max_frame_size = Some(size);
235 self
236 }
237
238 pub fn auto_pong(mut self, enabled: bool) -> Self {
240 self._config.auto_pong = enabled;
241 self
242 }
243
244 pub fn ping_interval(mut self, seconds: u64) -> Self {
246 self._config.ping_interval = Some(seconds);
247 self
248 }
249
250 pub fn connect_timeout(mut self, seconds: u64) -> Self {
252 self._config.connect_timeout = Some(seconds);
253 self
254 }
255
256 pub fn cleanup_interval(mut self, seconds: u64) -> Self {
258 self.cleanup_interval = Some(seconds);
259 self
260 }
261
262 pub fn no_cleanup(mut self) -> Self {
264 self.cleanup_interval = None;
265 self
266 }
267
268 pub fn build(self) -> WebSocketServer {
270 let mut server = WebSocketServer::with_config(self._config);
271
272 if let Some(interval) = self.cleanup_interval {
273 server.start_cleanup_task(interval);
274 }
275
276 server
277 }
278}
279
280impl Default for WebSocketServerBuilder {
281 fn default() -> Self {
282 Self::new()
283 }
284}