1use crate::config::WebSocketConfig;
44use chrono::{DateTime, Utc};
45use futures_util::{SinkExt, StreamExt};
46pub use hammerwork::archive::{ArchivalReason, ArchivalStats};
47use serde::{Deserialize, Serialize};
48use std::collections::HashMap;
49use tokio::sync::mpsc;
50use tracing::{debug, error, info, warn};
51use uuid::Uuid;
52use warp::ws::Message;
53
54#[derive(Debug)]
56pub struct WebSocketState {
57 config: WebSocketConfig,
58 connections: HashMap<Uuid, mpsc::UnboundedSender<Message>>,
59 broadcast_sender: mpsc::UnboundedSender<BroadcastMessage>,
60 broadcast_receiver: Option<mpsc::UnboundedReceiver<BroadcastMessage>>,
61}
62
63impl WebSocketState {
64 pub fn new(config: WebSocketConfig) -> Self {
65 let (broadcast_sender, broadcast_receiver) = mpsc::unbounded_channel();
66
67 Self {
68 config,
69 connections: HashMap::new(),
70 broadcast_sender,
71 broadcast_receiver: Some(broadcast_receiver),
72 }
73 }
74
75 pub async fn handle_connection(&mut self, websocket: warp::ws::WebSocket) -> crate::Result<()> {
77 let connection_id = Uuid::new_v4();
78
79 if self.connections.len() >= self.config.max_connections {
80 warn!("Maximum WebSocket connections reached, rejecting new connection");
81 return Ok(());
82 }
83
84 let (mut ws_sender, mut ws_receiver) = websocket.split();
85 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
86
87 self.connections.insert(connection_id, tx);
89 info!("WebSocket connection established: {}", connection_id);
90
91 let connection_id_clone = connection_id;
93 tokio::spawn(async move {
94 while let Some(message) = rx.recv().await {
95 if let Err(e) = ws_sender.send(message).await {
96 debug!(
97 "Failed to send WebSocket message to {}: {}",
98 connection_id_clone, e
99 );
100 break;
101 }
102 }
103 });
104
105 let broadcast_sender = self.broadcast_sender.clone();
107
108 while let Some(result) = ws_receiver.next().await {
109 match result {
110 Ok(message) => {
111 if let Err(e) = self
112 .handle_client_message(connection_id, message, &broadcast_sender)
113 .await
114 {
115 error!(
116 "Error handling client message from {}: {}",
117 connection_id, e
118 );
119 break;
120 }
121 }
122 Err(e) => {
123 debug!("WebSocket error for connection {}: {}", connection_id, e);
124 break;
125 }
126 }
127 }
128
129 self.connections.remove(&connection_id);
131 info!("WebSocket connection closed: {}", connection_id);
132
133 Ok(())
134 }
135
136 async fn handle_client_message(
138 &self,
139 connection_id: Uuid,
140 message: Message,
141 _broadcast_sender: &mpsc::UnboundedSender<BroadcastMessage>,
142 ) -> crate::Result<()> {
143 if message.is_text() {
144 if let Ok(text) = message.to_str() {
145 if let Ok(client_message) = serde_json::from_str::<ClientMessage>(text) {
146 debug!(
147 "Received message from {}: {:?}",
148 connection_id, client_message
149 );
150 self.handle_client_action(connection_id, client_message)
151 .await?;
152 } else {
153 warn!("Invalid message format from {}: {}", connection_id, text);
154 }
155 }
156 } else if message.is_ping() {
157 if let Some(sender) = self.connections.get(&connection_id) {
159 let pong_msg = Message::pong(message.as_bytes());
160 let _ = sender.send(pong_msg);
161 }
162 } else if message.is_pong() {
163 debug!("Pong received from {}", connection_id);
165 } else if message.is_close() {
166 debug!("Close message received from {}", connection_id);
167 } else if message.is_binary() {
168 warn!("Binary message not supported from {}", connection_id);
169 }
170
171 Ok(())
172 }
173
174 async fn handle_client_action(
176 &self,
177 _connection_id: Uuid,
178 message: ClientMessage,
179 ) -> crate::Result<()> {
180 match message {
181 ClientMessage::Subscribe { event_types } => {
182 info!("Client subscribed to events: {:?}", event_types);
183 }
185 ClientMessage::Unsubscribe { event_types } => {
186 info!("Client unsubscribed from events: {:?}", event_types);
187 }
189 ClientMessage::Ping => {
190 self.broadcast_to_all(ServerMessage::Pong).await?;
192 }
193 }
194
195 Ok(())
196 }
197
198 pub async fn broadcast_to_all(&self, message: ServerMessage) -> crate::Result<()> {
200 let json_message = serde_json::to_string(&message)?;
201 let ws_message = Message::text(json_message);
202
203 let mut disconnected = Vec::new();
204
205 for (&connection_id, sender) in &self.connections {
206 if sender.send(ws_message.clone()).is_err() {
207 disconnected.push(connection_id);
208 }
209 }
210
211 Ok(())
216 }
217
218 pub async fn publish_archive_event(
220 &self,
221 event: hammerwork::archive::ArchiveEvent,
222 ) -> crate::Result<()> {
223 let broadcast_message = match event {
224 hammerwork::archive::ArchiveEvent::JobArchived {
225 job_id,
226 queue,
227 reason,
228 } => BroadcastMessage::JobArchived {
229 job_id: job_id.to_string(),
230 queue,
231 reason,
232 },
233 hammerwork::archive::ArchiveEvent::JobRestored {
234 job_id,
235 queue,
236 restored_by,
237 } => BroadcastMessage::JobRestored {
238 job_id: job_id.to_string(),
239 queue,
240 restored_by,
241 },
242 hammerwork::archive::ArchiveEvent::BulkArchiveStarted {
243 operation_id,
244 estimated_jobs,
245 } => BroadcastMessage::BulkArchiveStarted {
246 operation_id,
247 estimated_jobs,
248 },
249 hammerwork::archive::ArchiveEvent::BulkArchiveProgress {
250 operation_id,
251 jobs_processed,
252 total,
253 } => BroadcastMessage::BulkArchiveProgress {
254 operation_id,
255 jobs_processed,
256 total,
257 },
258 hammerwork::archive::ArchiveEvent::BulkArchiveCompleted {
259 operation_id,
260 stats,
261 } => BroadcastMessage::BulkArchiveCompleted {
262 operation_id,
263 stats,
264 },
265 hammerwork::archive::ArchiveEvent::JobsPurged { count, older_than } => {
266 BroadcastMessage::JobsPurged { count, older_than }
267 }
268 };
269
270 if let Err(_) = self.broadcast_sender.send(broadcast_message) {
272 return Err(anyhow::anyhow!(
273 "Failed to send archive event to broadcast channel"
274 ));
275 }
276
277 Ok(())
278 }
279
280 pub async fn ping_all_connections(&self) {
282 let ping_message = Message::ping(b"ping");
283 let mut disconnected = Vec::new();
284
285 for (&connection_id, sender) in &self.connections {
286 if sender.send(ping_message.clone()).is_err() {
287 disconnected.push(connection_id);
288 }
289 }
290
291 if !disconnected.is_empty() {
292 debug!(
293 "Detected {} disconnected WebSocket clients during ping",
294 disconnected.len()
295 );
296 }
297 }
298
299 pub fn connection_count(&self) -> usize {
301 self.connections.len()
302 }
303
304 pub async fn start_broadcast_listener(&mut self) -> crate::Result<()> {
306 if let Some(mut receiver) = self.broadcast_receiver.take() {
307 tokio::spawn(async move {
308 while let Some(broadcast_message) = receiver.recv().await {
309 let server_message = match broadcast_message {
311 BroadcastMessage::QueueUpdate { queue_name, stats } => {
312 ServerMessage::QueueUpdate { queue_name, stats }
313 }
314 BroadcastMessage::JobUpdate { job } => ServerMessage::JobUpdate { job },
315 BroadcastMessage::SystemAlert { message, severity } => {
316 ServerMessage::SystemAlert { message, severity }
317 }
318 BroadcastMessage::JobArchived {
319 job_id,
320 queue,
321 reason,
322 } => ServerMessage::JobArchived {
323 job_id,
324 queue,
325 reason,
326 },
327 BroadcastMessage::JobRestored {
328 job_id,
329 queue,
330 restored_by,
331 } => ServerMessage::JobRestored {
332 job_id,
333 queue,
334 restored_by,
335 },
336 BroadcastMessage::BulkArchiveStarted {
337 operation_id,
338 estimated_jobs,
339 } => ServerMessage::BulkArchiveStarted {
340 operation_id,
341 estimated_jobs,
342 },
343 BroadcastMessage::BulkArchiveProgress {
344 operation_id,
345 jobs_processed,
346 total,
347 } => ServerMessage::BulkArchiveProgress {
348 operation_id,
349 jobs_processed,
350 total,
351 },
352 BroadcastMessage::BulkArchiveCompleted {
353 operation_id,
354 stats,
355 } => ServerMessage::BulkArchiveCompleted {
356 operation_id,
357 stats,
358 },
359 BroadcastMessage::JobsPurged { count, older_than } => {
360 ServerMessage::JobsPurged { count, older_than }
361 }
362 };
363
364 debug!("Broadcasting message: {:?}", server_message);
367 }
368 });
369 }
370 Ok(())
371 }
372}
373
374#[derive(Debug, Deserialize)]
376#[serde(tag = "type")]
377pub enum ClientMessage {
378 Subscribe { event_types: Vec<String> },
379 Unsubscribe { event_types: Vec<String> },
380 Ping,
381}
382
383#[derive(Debug, Serialize)]
385#[serde(tag = "type")]
386pub enum ServerMessage {
387 QueueUpdate {
388 queue_name: String,
389 stats: QueueStats,
390 },
391 JobUpdate {
392 job: JobUpdate,
393 },
394 SystemAlert {
395 message: String,
396 severity: AlertSeverity,
397 },
398 JobArchived {
399 job_id: String,
400 queue: String,
401 reason: ArchivalReason,
402 },
403 JobRestored {
404 job_id: String,
405 queue: String,
406 restored_by: Option<String>,
407 },
408 BulkArchiveStarted {
409 operation_id: String,
410 estimated_jobs: u64,
411 },
412 BulkArchiveProgress {
413 operation_id: String,
414 jobs_processed: u64,
415 total: u64,
416 },
417 BulkArchiveCompleted {
418 operation_id: String,
419 stats: ArchivalStats,
420 },
421 JobsPurged {
422 count: u64,
423 older_than: DateTime<Utc>,
424 },
425 Pong,
426}
427
428#[derive(Debug)]
430pub enum BroadcastMessage {
431 QueueUpdate {
432 queue_name: String,
433 stats: QueueStats,
434 },
435 JobUpdate {
436 job: JobUpdate,
437 },
438 SystemAlert {
439 message: String,
440 severity: AlertSeverity,
441 },
442 JobArchived {
443 job_id: String,
444 queue: String,
445 reason: ArchivalReason,
446 },
447 JobRestored {
448 job_id: String,
449 queue: String,
450 restored_by: Option<String>,
451 },
452 BulkArchiveStarted {
453 operation_id: String,
454 estimated_jobs: u64,
455 },
456 BulkArchiveProgress {
457 operation_id: String,
458 jobs_processed: u64,
459 total: u64,
460 },
461 BulkArchiveCompleted {
462 operation_id: String,
463 stats: ArchivalStats,
464 },
465 JobsPurged {
466 count: u64,
467 older_than: DateTime<Utc>,
468 },
469}
470
471#[derive(Debug, Serialize)]
473pub struct QueueStats {
474 pub pending_count: u64,
475 pub running_count: u64,
476 pub completed_count: u64,
477 pub failed_count: u64,
478 pub dead_count: u64,
479 pub throughput_per_minute: f64,
480 pub avg_processing_time_ms: f64,
481 pub error_rate: f64,
482 pub updated_at: chrono::DateTime<chrono::Utc>,
483}
484
485#[derive(Debug, Serialize)]
487pub struct JobUpdate {
488 pub id: String,
489 pub queue_name: String,
490 pub status: String,
491 pub priority: String,
492 pub attempts: i32,
493 pub updated_at: chrono::DateTime<chrono::Utc>,
494}
495
496#[derive(Debug, Serialize)]
498pub enum AlertSeverity {
499 Info,
500 Warning,
501 Error,
502 Critical,
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::WebSocketConfig;

    #[test]
    fn test_websocket_state_creation() {
        let config = WebSocketConfig::default();
        let state = WebSocketState::new(config);
        assert_eq!(state.connection_count(), 0);
    }

    #[test]
    fn test_client_message_deserialization() {
        let json = r#"{"type": "Subscribe", "event_types": ["queue_updates", "job_updates"]}"#;
        let message: ClientMessage = serde_json::from_str(json).unwrap();

        match message {
            ClientMessage::Subscribe { event_types } => {
                assert_eq!(event_types.len(), 2);
                assert!(event_types.contains(&"queue_updates".to_string()));
            }
            _ => panic!("Wrong message type"),
        }
    }

    /// The unit `Ping` request needs nothing but its `type` tag.
    #[test]
    fn test_client_ping_deserialization() {
        let message: ClientMessage = serde_json::from_str(r#"{"type": "Ping"}"#).unwrap();
        assert!(matches!(message, ClientMessage::Ping));
    }

    /// Unknown `type` tags must be rejected rather than silently accepted.
    #[test]
    fn test_unknown_client_message_is_rejected() {
        let result = serde_json::from_str::<ClientMessage>(r#"{"type": "Reboot"}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_server_message_serialization() {
        let message = ServerMessage::SystemAlert {
            message: "High error rate detected".to_string(),
            severity: AlertSeverity::Warning,
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("type"));
        assert!(json.contains("SystemAlert"));
        assert!(json.contains("High error rate detected"));

        // Pin the exact field encodings, not just substrings.
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["type"], "SystemAlert");
        assert_eq!(value["message"], "High error rate detected");
        assert_eq!(value["severity"], "Warning");
    }

    /// The Pong reply is a bare tagged object with no extra fields.
    #[test]
    fn test_pong_serialization() {
        let json = serde_json::to_string(&ServerMessage::Pong).unwrap();
        assert_eq!(json, r#"{"type":"Pong"}"#);
    }

    #[tokio::test]
    async fn test_broadcast_to_all() {
        let config = WebSocketConfig::default();
        let state = WebSocketState::new(config);

        let message = ServerMessage::Pong;
        let result = state.broadcast_to_all(message).await;
        assert!(result.is_ok());
        assert_eq!(state.connection_count(), 0);
    }

    /// Starting the listener twice must not fail; the receiver can only be taken once.
    #[tokio::test]
    async fn test_start_broadcast_listener_twice() {
        let mut state = WebSocketState::new(WebSocketConfig::default());
        assert!(state.start_broadcast_listener().await.is_ok());
        assert!(state.start_broadcast_listener().await.is_ok());
    }
}