1use crate::config::WebSocketConfig;
44use futures_util::{SinkExt, StreamExt};
45use serde::{Deserialize, Serialize};
46use std::collections::HashMap;
47use tokio::sync::mpsc;
48use tracing::{debug, error, info, warn};
49use uuid::Uuid;
50use warp::ws::Message;
51
52#[derive(Debug)]
54pub struct WebSocketState {
55 config: WebSocketConfig,
56 connections: HashMap<Uuid, mpsc::UnboundedSender<Message>>,
57 broadcast_sender: mpsc::UnboundedSender<BroadcastMessage>,
58 broadcast_receiver: Option<mpsc::UnboundedReceiver<BroadcastMessage>>,
59}
60
61impl WebSocketState {
62 pub fn new(config: WebSocketConfig) -> Self {
63 let (broadcast_sender, broadcast_receiver) = mpsc::unbounded_channel();
64
65 Self {
66 config,
67 connections: HashMap::new(),
68 broadcast_sender,
69 broadcast_receiver: Some(broadcast_receiver),
70 }
71 }
72
73 pub async fn handle_connection(&mut self, websocket: warp::ws::WebSocket) -> crate::Result<()> {
75 let connection_id = Uuid::new_v4();
76
77 if self.connections.len() >= self.config.max_connections {
78 warn!("Maximum WebSocket connections reached, rejecting new connection");
79 return Ok(());
80 }
81
82 let (mut ws_sender, mut ws_receiver) = websocket.split();
83 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
84
85 self.connections.insert(connection_id, tx);
87 info!("WebSocket connection established: {}", connection_id);
88
89 let connection_id_clone = connection_id;
91 tokio::spawn(async move {
92 while let Some(message) = rx.recv().await {
93 if let Err(e) = ws_sender.send(message).await {
94 debug!(
95 "Failed to send WebSocket message to {}: {}",
96 connection_id_clone, e
97 );
98 break;
99 }
100 }
101 });
102
103 let broadcast_sender = self.broadcast_sender.clone();
105
106 while let Some(result) = ws_receiver.next().await {
107 match result {
108 Ok(message) => {
109 if let Err(e) = self
110 .handle_client_message(connection_id, message, &broadcast_sender)
111 .await
112 {
113 error!(
114 "Error handling client message from {}: {}",
115 connection_id, e
116 );
117 break;
118 }
119 }
120 Err(e) => {
121 debug!("WebSocket error for connection {}: {}", connection_id, e);
122 break;
123 }
124 }
125 }
126
127 self.connections.remove(&connection_id);
129 info!("WebSocket connection closed: {}", connection_id);
130
131 Ok(())
132 }
133
134 async fn handle_client_message(
136 &self,
137 connection_id: Uuid,
138 message: Message,
139 _broadcast_sender: &mpsc::UnboundedSender<BroadcastMessage>,
140 ) -> crate::Result<()> {
141 if message.is_text() {
142 if let Ok(text) = message.to_str() {
143 if let Ok(client_message) = serde_json::from_str::<ClientMessage>(text) {
144 debug!(
145 "Received message from {}: {:?}",
146 connection_id, client_message
147 );
148 self.handle_client_action(connection_id, client_message)
149 .await?;
150 } else {
151 warn!("Invalid message format from {}: {}", connection_id, text);
152 }
153 }
154 } else if message.is_ping() {
155 if let Some(sender) = self.connections.get(&connection_id) {
157 let pong_msg = Message::pong(message.as_bytes());
158 let _ = sender.send(pong_msg);
159 }
160 } else if message.is_pong() {
161 debug!("Pong received from {}", connection_id);
163 } else if message.is_close() {
164 debug!("Close message received from {}", connection_id);
165 } else if message.is_binary() {
166 warn!("Binary message not supported from {}", connection_id);
167 }
168
169 Ok(())
170 }
171
172 async fn handle_client_action(
174 &self,
175 _connection_id: Uuid,
176 message: ClientMessage,
177 ) -> crate::Result<()> {
178 match message {
179 ClientMessage::Subscribe { event_types } => {
180 info!("Client subscribed to events: {:?}", event_types);
181 }
183 ClientMessage::Unsubscribe { event_types } => {
184 info!("Client unsubscribed from events: {:?}", event_types);
185 }
187 ClientMessage::Ping => {
188 self.broadcast_to_all(ServerMessage::Pong).await?;
190 }
191 }
192
193 Ok(())
194 }
195
196 pub async fn broadcast_to_all(&self, message: ServerMessage) -> crate::Result<()> {
198 let json_message = serde_json::to_string(&message)?;
199 let ws_message = Message::text(json_message);
200
201 let mut disconnected = Vec::new();
202
203 for (&connection_id, sender) in &self.connections {
204 if sender.send(ws_message.clone()).is_err() {
205 disconnected.push(connection_id);
206 }
207 }
208
209 Ok(())
214 }
215
216 pub async fn ping_all_connections(&self) {
218 let ping_message = Message::ping(b"ping");
219 let mut disconnected = Vec::new();
220
221 for (&connection_id, sender) in &self.connections {
222 if sender.send(ping_message.clone()).is_err() {
223 disconnected.push(connection_id);
224 }
225 }
226
227 if !disconnected.is_empty() {
228 debug!(
229 "Detected {} disconnected WebSocket clients during ping",
230 disconnected.len()
231 );
232 }
233 }
234
235 pub fn connection_count(&self) -> usize {
237 self.connections.len()
238 }
239
240 pub async fn start_broadcast_listener(&mut self) -> crate::Result<()> {
242 if let Some(mut receiver) = self.broadcast_receiver.take() {
243 tokio::spawn(async move {
244 while let Some(broadcast_message) = receiver.recv().await {
245 let server_message = match broadcast_message {
247 BroadcastMessage::QueueUpdate { queue_name, stats } => {
248 ServerMessage::QueueUpdate { queue_name, stats }
249 }
250 BroadcastMessage::JobUpdate { job } => ServerMessage::JobUpdate { job },
251 BroadcastMessage::SystemAlert { message, severity } => {
252 ServerMessage::SystemAlert { message, severity }
253 }
254 };
255
256 debug!("Broadcasting message: {:?}", server_message);
259 }
260 });
261 }
262 Ok(())
263 }
264}
265
266#[derive(Debug, Deserialize)]
268#[serde(tag = "type")]
269pub enum ClientMessage {
270 Subscribe { event_types: Vec<String> },
271 Unsubscribe { event_types: Vec<String> },
272 Ping,
273}
274
275#[derive(Debug, Serialize)]
277#[serde(tag = "type")]
278pub enum ServerMessage {
279 QueueUpdate {
280 queue_name: String,
281 stats: QueueStats,
282 },
283 JobUpdate {
284 job: JobUpdate,
285 },
286 SystemAlert {
287 message: String,
288 severity: AlertSeverity,
289 },
290 Pong,
291}
292
293#[derive(Debug)]
295pub enum BroadcastMessage {
296 QueueUpdate {
297 queue_name: String,
298 stats: QueueStats,
299 },
300 JobUpdate {
301 job: JobUpdate,
302 },
303 SystemAlert {
304 message: String,
305 severity: AlertSeverity,
306 },
307}
308
309#[derive(Debug, Serialize)]
311pub struct QueueStats {
312 pub pending_count: u64,
313 pub running_count: u64,
314 pub completed_count: u64,
315 pub failed_count: u64,
316 pub dead_count: u64,
317 pub throughput_per_minute: f64,
318 pub avg_processing_time_ms: f64,
319 pub error_rate: f64,
320 pub updated_at: chrono::DateTime<chrono::Utc>,
321}
322
323#[derive(Debug, Serialize)]
325pub struct JobUpdate {
326 pub id: String,
327 pub queue_name: String,
328 pub status: String,
329 pub priority: String,
330 pub attempts: i32,
331 pub updated_at: chrono::DateTime<chrono::Utc>,
332}
333
334#[derive(Debug, Serialize)]
336pub enum AlertSeverity {
337 Info,
338 Warning,
339 Error,
340 Critical,
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::WebSocketConfig;

    /// A freshly constructed state has no registered connections.
    #[test]
    fn test_websocket_state_creation() {
        let state = WebSocketState::new(WebSocketConfig::default());
        assert_eq!(state.connection_count(), 0);
    }

    /// A tagged JSON object deserializes into the matching `ClientMessage` variant.
    #[test]
    fn test_client_message_deserialization() {
        let json = r#"{"type": "Subscribe", "event_types": ["queue_updates", "job_updates"]}"#;
        let parsed: ClientMessage = serde_json::from_str(json).unwrap();

        if let ClientMessage::Subscribe { event_types } = parsed {
            assert_eq!(event_types.len(), 2);
            assert!(event_types.iter().any(|name| name == "queue_updates"));
        } else {
            panic!("Wrong message type");
        }
    }

    /// Serialized server messages carry the type tag, the variant name and the payload.
    #[test]
    fn test_server_message_serialization() {
        let alert = ServerMessage::SystemAlert {
            message: "High error rate detected".to_string(),
            severity: AlertSeverity::Warning,
        };
        let json = serde_json::to_string(&alert).unwrap();

        for needle in ["type", "SystemAlert", "High error rate detected"] {
            assert!(json.contains(needle));
        }
    }

    /// Broadcasting with no registered connections succeeds.
    #[tokio::test]
    async fn test_broadcast_to_all() {
        let state = WebSocketState::new(WebSocketConfig::default());

        let outcome = state.broadcast_to_all(ServerMessage::Pong).await;
        assert!(outcome.is_ok());
    }
}