hammerwork_web/
websocket.rs

1//! WebSocket implementation for real-time dashboard updates.
2//!
3//! This module provides WebSocket functionality for real-time communication between
4//! the dashboard frontend and backend. It supports connection management, message
5//! broadcasting, and automatic ping/pong for connection health.
6//!
7//! # Message Types
8//!
9//! The WebSocket API supports several message types for different events:
10//!
11//! ```rust
12//! use hammerwork_web::websocket::{ClientMessage, ServerMessage, AlertSeverity};
13//! use serde_json::json;
14//!
15//! // Client messages (sent from browser to server)
16//! let subscribe_msg = ClientMessage::Subscribe {
17//!     event_types: vec!["queue_updates".to_string(), "job_updates".to_string()],
18//! };
19//!
20//! let ping_msg = ClientMessage::Ping;
21//!
22//! // Server messages (sent from server to browser)
23//! let alert_msg = ServerMessage::SystemAlert {
24//!     message: "High error rate detected".to_string(),
25//!     severity: AlertSeverity::Warning,
26//! };
27//!
28//! let pong_msg = ServerMessage::Pong;
29//! ```
30//!
31//! # Connection Management
32//!
33//! ```rust
34//! use hammerwork_web::websocket::WebSocketState;
35//! use hammerwork_web::config::WebSocketConfig;
36//!
37//! let config = WebSocketConfig::default();
38//! let ws_state = WebSocketState::new(config);
39//!
40//! assert_eq!(ws_state.connection_count(), 0);
41//! ```
42
43use crate::config::WebSocketConfig;
44use chrono::{DateTime, Utc};
45use futures_util::{SinkExt, StreamExt};
46pub use hammerwork::archive::{ArchivalReason, ArchivalStats};
47use serde::{Deserialize, Serialize};
48use std::collections::HashMap;
49use tokio::sync::mpsc;
50use tracing::{debug, error, info, warn};
51use uuid::Uuid;
52use warp::ws::Message;
53
/// WebSocket connection state manager
///
/// Owns the map of live client connections together with both halves of an
/// internal, unbounded broadcast channel.
#[derive(Debug)]
pub struct WebSocketState {
    /// WebSocket settings; `max_connections` is consulted when a new
    /// connection is accepted.
    config: WebSocketConfig,
    /// Outgoing-message sender for every live connection, keyed by the id
    /// assigned in `handle_connection`.
    connections: HashMap<Uuid, mpsc::UnboundedSender<Message>>,
    /// Sending half of the broadcast channel; `publish_archive_event` pushes
    /// onto it.
    broadcast_sender: mpsc::UnboundedSender<BroadcastMessage>,
    /// Receiving half of the broadcast channel. It is `Some` after `new` and
    /// becomes `None` once `start_broadcast_listener` has taken it.
    broadcast_receiver: Option<mpsc::UnboundedReceiver<BroadcastMessage>>,
}
62
63impl WebSocketState {
64    pub fn new(config: WebSocketConfig) -> Self {
65        let (broadcast_sender, broadcast_receiver) = mpsc::unbounded_channel();
66
67        Self {
68            config,
69            connections: HashMap::new(),
70            broadcast_sender,
71            broadcast_receiver: Some(broadcast_receiver),
72        }
73    }
74
75    /// Handle a new WebSocket connection
76    pub async fn handle_connection(&mut self, websocket: warp::ws::WebSocket) -> crate::Result<()> {
77        let connection_id = Uuid::new_v4();
78
79        if self.connections.len() >= self.config.max_connections {
80            warn!("Maximum WebSocket connections reached, rejecting new connection");
81            return Ok(());
82        }
83
84        let (mut ws_sender, mut ws_receiver) = websocket.split();
85        let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
86
87        // Store the connection
88        self.connections.insert(connection_id, tx);
89        info!("WebSocket connection established: {}", connection_id);
90
91        // Spawn task to handle outgoing messages to this client
92        let connection_id_clone = connection_id;
93        tokio::spawn(async move {
94            while let Some(message) = rx.recv().await {
95                if let Err(e) = ws_sender.send(message).await {
96                    debug!(
97                        "Failed to send WebSocket message to {}: {}",
98                        connection_id_clone, e
99                    );
100                    break;
101                }
102            }
103        });
104
105        // Handle incoming messages from this client
106        let broadcast_sender = self.broadcast_sender.clone();
107
108        while let Some(result) = ws_receiver.next().await {
109            match result {
110                Ok(message) => {
111                    if let Err(e) = self
112                        .handle_client_message(connection_id, message, &broadcast_sender)
113                        .await
114                    {
115                        error!(
116                            "Error handling client message from {}: {}",
117                            connection_id, e
118                        );
119                        break;
120                    }
121                }
122                Err(e) => {
123                    debug!("WebSocket error for connection {}: {}", connection_id, e);
124                    break;
125                }
126            }
127        }
128
129        // Clean up connection
130        self.connections.remove(&connection_id);
131        info!("WebSocket connection closed: {}", connection_id);
132
133        Ok(())
134    }
135
136    /// Handle a message from a client
137    async fn handle_client_message(
138        &self,
139        connection_id: Uuid,
140        message: Message,
141        _broadcast_sender: &mpsc::UnboundedSender<BroadcastMessage>,
142    ) -> crate::Result<()> {
143        if message.is_text() {
144            if let Ok(text) = message.to_str() {
145                if let Ok(client_message) = serde_json::from_str::<ClientMessage>(text) {
146                    debug!(
147                        "Received message from {}: {:?}",
148                        connection_id, client_message
149                    );
150                    self.handle_client_action(connection_id, client_message)
151                        .await?;
152                } else {
153                    warn!("Invalid message format from {}: {}", connection_id, text);
154                }
155            }
156        } else if message.is_ping() {
157            // Send pong response
158            if let Some(sender) = self.connections.get(&connection_id) {
159                let pong_msg = Message::pong(message.as_bytes());
160                let _ = sender.send(pong_msg);
161            }
162        } else if message.is_pong() {
163            // Pong received - connection is alive
164            debug!("Pong received from {}", connection_id);
165        } else if message.is_close() {
166            debug!("Close message received from {}", connection_id);
167        } else if message.is_binary() {
168            warn!("Binary message not supported from {}", connection_id);
169        }
170
171        Ok(())
172    }
173
174    /// Handle a client action
175    async fn handle_client_action(
176        &self,
177        _connection_id: Uuid,
178        message: ClientMessage,
179    ) -> crate::Result<()> {
180        match message {
181            ClientMessage::Subscribe { event_types } => {
182                info!("Client subscribed to events: {:?}", event_types);
183                // TODO: Store subscription preferences per connection
184            }
185            ClientMessage::Unsubscribe { event_types } => {
186                info!("Client unsubscribed from events: {:?}", event_types);
187                // TODO: Update subscription preferences
188            }
189            ClientMessage::Ping => {
190                // Client ping - we'll send a pong back via broadcast
191                self.broadcast_to_all(ServerMessage::Pong).await?;
192            }
193        }
194
195        Ok(())
196    }
197
198    /// Broadcast a message to all connected clients
199    pub async fn broadcast_to_all(&self, message: ServerMessage) -> crate::Result<()> {
200        let json_message = serde_json::to_string(&message)?;
201        let ws_message = Message::text(json_message);
202
203        let mut disconnected = Vec::new();
204
205        for (&connection_id, sender) in &self.connections {
206            if sender.send(ws_message.clone()).is_err() {
207                disconnected.push(connection_id);
208            }
209        }
210
211        // Clean up disconnected clients
212        // Note: In a real implementation, we'd need mutable access to self.connections
213        // This would be handled by the connection cleanup in handle_connection
214
215        Ok(())
216    }
217
218    /// Publish an archive event to all connected clients
219    pub async fn publish_archive_event(
220        &self,
221        event: hammerwork::archive::ArchiveEvent,
222    ) -> crate::Result<()> {
223        let broadcast_message = match event {
224            hammerwork::archive::ArchiveEvent::JobArchived {
225                job_id,
226                queue,
227                reason,
228            } => BroadcastMessage::JobArchived {
229                job_id: job_id.to_string(),
230                queue,
231                reason,
232            },
233            hammerwork::archive::ArchiveEvent::JobRestored {
234                job_id,
235                queue,
236                restored_by,
237            } => BroadcastMessage::JobRestored {
238                job_id: job_id.to_string(),
239                queue,
240                restored_by,
241            },
242            hammerwork::archive::ArchiveEvent::BulkArchiveStarted {
243                operation_id,
244                estimated_jobs,
245            } => BroadcastMessage::BulkArchiveStarted {
246                operation_id,
247                estimated_jobs,
248            },
249            hammerwork::archive::ArchiveEvent::BulkArchiveProgress {
250                operation_id,
251                jobs_processed,
252                total,
253            } => BroadcastMessage::BulkArchiveProgress {
254                operation_id,
255                jobs_processed,
256                total,
257            },
258            hammerwork::archive::ArchiveEvent::BulkArchiveCompleted {
259                operation_id,
260                stats,
261            } => BroadcastMessage::BulkArchiveCompleted {
262                operation_id,
263                stats,
264            },
265            hammerwork::archive::ArchiveEvent::JobsPurged { count, older_than } => {
266                BroadcastMessage::JobsPurged { count, older_than }
267            }
268        };
269
270        // Send to the broadcast channel
271        if let Err(_) = self.broadcast_sender.send(broadcast_message) {
272            return Err(anyhow::anyhow!(
273                "Failed to send archive event to broadcast channel"
274            ));
275        }
276
277        Ok(())
278    }
279
280    /// Send ping to all connections to keep them alive
281    pub async fn ping_all_connections(&self) {
282        let ping_message = Message::ping(b"ping");
283        let mut disconnected = Vec::new();
284
285        for (&connection_id, sender) in &self.connections {
286            if sender.send(ping_message.clone()).is_err() {
287                disconnected.push(connection_id);
288            }
289        }
290
291        if !disconnected.is_empty() {
292            debug!(
293                "Detected {} disconnected WebSocket clients during ping",
294                disconnected.len()
295            );
296        }
297    }
298
299    /// Get current connection count
300    pub fn connection_count(&self) -> usize {
301        self.connections.len()
302    }
303
304    /// Start the broadcast listener task
305    pub async fn start_broadcast_listener(&mut self) -> crate::Result<()> {
306        if let Some(mut receiver) = self.broadcast_receiver.take() {
307            tokio::spawn(async move {
308                while let Some(broadcast_message) = receiver.recv().await {
309                    // Convert broadcast message to server message and send to all clients
310                    let server_message = match broadcast_message {
311                        BroadcastMessage::QueueUpdate { queue_name, stats } => {
312                            ServerMessage::QueueUpdate { queue_name, stats }
313                        }
314                        BroadcastMessage::JobUpdate { job } => ServerMessage::JobUpdate { job },
315                        BroadcastMessage::SystemAlert { message, severity } => {
316                            ServerMessage::SystemAlert { message, severity }
317                        }
318                        BroadcastMessage::JobArchived {
319                            job_id,
320                            queue,
321                            reason,
322                        } => ServerMessage::JobArchived {
323                            job_id,
324                            queue,
325                            reason,
326                        },
327                        BroadcastMessage::JobRestored {
328                            job_id,
329                            queue,
330                            restored_by,
331                        } => ServerMessage::JobRestored {
332                            job_id,
333                            queue,
334                            restored_by,
335                        },
336                        BroadcastMessage::BulkArchiveStarted {
337                            operation_id,
338                            estimated_jobs,
339                        } => ServerMessage::BulkArchiveStarted {
340                            operation_id,
341                            estimated_jobs,
342                        },
343                        BroadcastMessage::BulkArchiveProgress {
344                            operation_id,
345                            jobs_processed,
346                            total,
347                        } => ServerMessage::BulkArchiveProgress {
348                            operation_id,
349                            jobs_processed,
350                            total,
351                        },
352                        BroadcastMessage::BulkArchiveCompleted {
353                            operation_id,
354                            stats,
355                        } => ServerMessage::BulkArchiveCompleted {
356                            operation_id,
357                            stats,
358                        },
359                        BroadcastMessage::JobsPurged { count, older_than } => {
360                            ServerMessage::JobsPurged { count, older_than }
361                        }
362                    };
363
364                    // Note: We'd need access to the WebSocketState here to broadcast
365                    // In a real implementation, this would be handled differently
366                    debug!("Broadcasting message: {:?}", server_message);
367                }
368            });
369        }
370        Ok(())
371    }
372}
373
/// Messages sent from client to server
///
/// Deserialized from JSON text frames. The `type` field of the JSON object
/// selects the variant, e.g.
/// `{"type": "Subscribe", "event_types": ["queue_updates"]}`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub enum ClientMessage {
    /// Request to receive the listed event types. The server currently only
    /// logs the request.
    Subscribe { event_types: Vec<String> },
    /// Request to stop receiving the listed event types. The server currently
    /// only logs the request.
    Unsubscribe { event_types: Vec<String> },
    /// Application-level ping; the server answers with `ServerMessage::Pong`.
    Ping,
}
382
/// Messages sent from server to client
///
/// Serialized to JSON with the variant name stored in a `type` field next to
/// the variant's own fields. Every variant except `Pong` has a counterpart of
/// the same name and shape in `BroadcastMessage`.
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
pub enum ServerMessage {
    /// Statistics snapshot for the queue named `queue_name`.
    QueueUpdate {
        queue_name: String,
        stats: QueueStats,
    },
    /// Update describing a single job.
    JobUpdate {
        job: JobUpdate,
    },
    /// Free-text alert together with its severity.
    SystemAlert {
        message: String,
        severity: AlertSeverity,
    },
    /// Counterpart of `ArchiveEvent::JobArchived`; `job_id` is the
    /// stringified job id.
    JobArchived {
        job_id: String,
        queue: String,
        reason: ArchivalReason,
    },
    /// Counterpart of `ArchiveEvent::JobRestored`; `job_id` is the
    /// stringified job id.
    JobRestored {
        job_id: String,
        queue: String,
        restored_by: Option<String>,
    },
    /// Counterpart of `ArchiveEvent::BulkArchiveStarted`.
    BulkArchiveStarted {
        operation_id: String,
        estimated_jobs: u64,
    },
    /// Counterpart of `ArchiveEvent::BulkArchiveProgress`.
    BulkArchiveProgress {
        operation_id: String,
        jobs_processed: u64,
        total: u64,
    },
    /// Counterpart of `ArchiveEvent::BulkArchiveCompleted`.
    BulkArchiveCompleted {
        operation_id: String,
        stats: ArchivalStats,
    },
    /// Counterpart of `ArchiveEvent::JobsPurged`.
    JobsPurged {
        count: u64,
        older_than: DateTime<Utc>,
    },
    /// Reply to `ClientMessage::Ping`.
    Pong,
}
427
/// Internal broadcast messages
///
/// Items carried on the broadcast channel owned by `WebSocketState`. The
/// listener started by `start_broadcast_listener` converts each variant into
/// the `ServerMessage` variant of the same name. The archive-related variants
/// are produced by `publish_archive_event`; the first three variants are not
/// constructed in the visible part of this module.
#[derive(Debug)]
pub enum BroadcastMessage {
    /// Statistics snapshot for the queue named `queue_name`.
    QueueUpdate {
        queue_name: String,
        stats: QueueStats,
    },
    /// Update describing a single job.
    JobUpdate {
        job: JobUpdate,
    },
    /// Free-text alert together with its severity.
    SystemAlert {
        message: String,
        severity: AlertSeverity,
    },
    /// Built from `ArchiveEvent::JobArchived`; `job_id` is the stringified
    /// job id.
    JobArchived {
        job_id: String,
        queue: String,
        reason: ArchivalReason,
    },
    /// Built from `ArchiveEvent::JobRestored`; `job_id` is the stringified
    /// job id.
    JobRestored {
        job_id: String,
        queue: String,
        restored_by: Option<String>,
    },
    /// Built from `ArchiveEvent::BulkArchiveStarted`.
    BulkArchiveStarted {
        operation_id: String,
        estimated_jobs: u64,
    },
    /// Built from `ArchiveEvent::BulkArchiveProgress`.
    BulkArchiveProgress {
        operation_id: String,
        jobs_processed: u64,
        total: u64,
    },
    /// Built from `ArchiveEvent::BulkArchiveCompleted`.
    BulkArchiveCompleted {
        operation_id: String,
        stats: ArchivalStats,
    },
    /// Built from `ArchiveEvent::JobsPurged`.
    JobsPurged {
        count: u64,
        older_than: DateTime<Utc>,
    },
}
470
471/// Queue statistics for WebSocket updates
472#[derive(Debug, Serialize)]
473pub struct QueueStats {
474    pub pending_count: u64,
475    pub running_count: u64,
476    pub completed_count: u64,
477    pub failed_count: u64,
478    pub dead_count: u64,
479    pub throughput_per_minute: f64,
480    pub avg_processing_time_ms: f64,
481    pub error_rate: f64,
482    pub updated_at: chrono::DateTime<chrono::Utc>,
483}
484
485/// Job update information
486#[derive(Debug, Serialize)]
487pub struct JobUpdate {
488    pub id: String,
489    pub queue_name: String,
490    pub status: String,
491    pub priority: String,
492    pub attempts: i32,
493    pub updated_at: chrono::DateTime<chrono::Utc>,
494}
495
496/// Alert severity levels
497#[derive(Debug, Serialize)]
498pub enum AlertSeverity {
499    Info,
500    Warning,
501    Error,
502    Critical,
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::WebSocketConfig;

    /// A freshly constructed state reports zero connections.
    #[test]
    fn test_websocket_state_creation() {
        let state = WebSocketState::new(WebSocketConfig::default());
        assert_eq!(state.connection_count(), 0);
    }

    /// A tagged `Subscribe` JSON object deserializes with its event list intact.
    #[test]
    fn test_client_message_deserialization() {
        let json = r#"{"type": "Subscribe", "event_types": ["queue_updates", "job_updates"]}"#;
        let parsed: ClientMessage = serde_json::from_str(json).unwrap();

        if let ClientMessage::Subscribe { event_types } = parsed {
            assert_eq!(event_types.len(), 2);
            assert!(event_types.contains(&"queue_updates".to_string()));
        } else {
            panic!("Wrong message type");
        }
    }

    /// A serialized `SystemAlert` contains the tag key, the variant name and the text.
    #[test]
    fn test_server_message_serialization() {
        let alert = ServerMessage::SystemAlert {
            message: "High error rate detected".to_string(),
            severity: AlertSeverity::Warning,
        };

        let json = serde_json::to_string(&alert).unwrap();
        for needle in ["type", "SystemAlert", "High error rate detected"] {
            assert!(json.contains(needle));
        }
    }

    /// Broadcasting with no connections registered succeeds.
    #[tokio::test]
    async fn test_broadcast_to_all() {
        let state = WebSocketState::new(WebSocketConfig::default());

        let outcome = state.broadcast_to_all(ServerMessage::Pong).await;
        assert!(outcome.is_ok());
    }
}
553}