hammerwork_web/
websocket.rs

1//! WebSocket implementation for real-time dashboard updates.
2//!
3//! This module provides WebSocket functionality for real-time communication between
4//! the dashboard frontend and backend. It supports connection management, message
5//! broadcasting, and automatic ping/pong for connection health.
6//!
7//! # Message Types
8//!
9//! The WebSocket API supports several message types for different events:
10//!
11//! ```rust
12//! use hammerwork_web::websocket::{ClientMessage, ServerMessage, AlertSeverity};
13//! use serde_json::json;
14//!
15//! // Client messages (sent from browser to server)
16//! let subscribe_msg = ClientMessage::Subscribe {
17//!     event_types: vec!["queue_updates".to_string(), "job_updates".to_string()],
18//! };
19//!
20//! let ping_msg = ClientMessage::Ping;
21//!
22//! // Server messages (sent from server to browser)
23//! let alert_msg = ServerMessage::SystemAlert {
24//!     message: "High error rate detected".to_string(),
25//!     severity: AlertSeverity::Warning,
26//! };
27//!
28//! let pong_msg = ServerMessage::Pong;
29//! ```
30//!
31//! # Connection Management
32//!
33//! ```rust
34//! use hammerwork_web::websocket::WebSocketState;
35//! use hammerwork_web::config::WebSocketConfig;
36//!
37//! let config = WebSocketConfig::default();
38//! let ws_state = WebSocketState::new(config);
39//!
40//! assert_eq!(ws_state.connection_count(), 0);
41//! ```
42
43use crate::config::WebSocketConfig;
44use futures_util::{SinkExt, StreamExt};
45use serde::{Deserialize, Serialize};
46use std::collections::HashMap;
47use tokio::sync::mpsc;
48use tracing::{debug, error, info, warn};
49use uuid::Uuid;
50use warp::ws::Message;
51
52/// WebSocket connection state manager
53#[derive(Debug)]
54pub struct WebSocketState {
55    config: WebSocketConfig,
56    connections: HashMap<Uuid, mpsc::UnboundedSender<Message>>,
57    broadcast_sender: mpsc::UnboundedSender<BroadcastMessage>,
58    broadcast_receiver: Option<mpsc::UnboundedReceiver<BroadcastMessage>>,
59}
60
61impl WebSocketState {
62    pub fn new(config: WebSocketConfig) -> Self {
63        let (broadcast_sender, broadcast_receiver) = mpsc::unbounded_channel();
64
65        Self {
66            config,
67            connections: HashMap::new(),
68            broadcast_sender,
69            broadcast_receiver: Some(broadcast_receiver),
70        }
71    }
72
73    /// Handle a new WebSocket connection
74    pub async fn handle_connection(&mut self, websocket: warp::ws::WebSocket) -> crate::Result<()> {
75        let connection_id = Uuid::new_v4();
76
77        if self.connections.len() >= self.config.max_connections {
78            warn!("Maximum WebSocket connections reached, rejecting new connection");
79            return Ok(());
80        }
81
82        let (mut ws_sender, mut ws_receiver) = websocket.split();
83        let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
84
85        // Store the connection
86        self.connections.insert(connection_id, tx);
87        info!("WebSocket connection established: {}", connection_id);
88
89        // Spawn task to handle outgoing messages to this client
90        let connection_id_clone = connection_id;
91        tokio::spawn(async move {
92            while let Some(message) = rx.recv().await {
93                if let Err(e) = ws_sender.send(message).await {
94                    debug!(
95                        "Failed to send WebSocket message to {}: {}",
96                        connection_id_clone, e
97                    );
98                    break;
99                }
100            }
101        });
102
103        // Handle incoming messages from this client
104        let broadcast_sender = self.broadcast_sender.clone();
105
106        while let Some(result) = ws_receiver.next().await {
107            match result {
108                Ok(message) => {
109                    if let Err(e) = self
110                        .handle_client_message(connection_id, message, &broadcast_sender)
111                        .await
112                    {
113                        error!(
114                            "Error handling client message from {}: {}",
115                            connection_id, e
116                        );
117                        break;
118                    }
119                }
120                Err(e) => {
121                    debug!("WebSocket error for connection {}: {}", connection_id, e);
122                    break;
123                }
124            }
125        }
126
127        // Clean up connection
128        self.connections.remove(&connection_id);
129        info!("WebSocket connection closed: {}", connection_id);
130
131        Ok(())
132    }
133
134    /// Handle a message from a client
135    async fn handle_client_message(
136        &self,
137        connection_id: Uuid,
138        message: Message,
139        _broadcast_sender: &mpsc::UnboundedSender<BroadcastMessage>,
140    ) -> crate::Result<()> {
141        if message.is_text() {
142            if let Ok(text) = message.to_str() {
143                if let Ok(client_message) = serde_json::from_str::<ClientMessage>(text) {
144                    debug!(
145                        "Received message from {}: {:?}",
146                        connection_id, client_message
147                    );
148                    self.handle_client_action(connection_id, client_message)
149                        .await?;
150                } else {
151                    warn!("Invalid message format from {}: {}", connection_id, text);
152                }
153            }
154        } else if message.is_ping() {
155            // Send pong response
156            if let Some(sender) = self.connections.get(&connection_id) {
157                let pong_msg = Message::pong(message.as_bytes());
158                let _ = sender.send(pong_msg);
159            }
160        } else if message.is_pong() {
161            // Pong received - connection is alive
162            debug!("Pong received from {}", connection_id);
163        } else if message.is_close() {
164            debug!("Close message received from {}", connection_id);
165        } else if message.is_binary() {
166            warn!("Binary message not supported from {}", connection_id);
167        }
168
169        Ok(())
170    }
171
172    /// Handle a client action
173    async fn handle_client_action(
174        &self,
175        _connection_id: Uuid,
176        message: ClientMessage,
177    ) -> crate::Result<()> {
178        match message {
179            ClientMessage::Subscribe { event_types } => {
180                info!("Client subscribed to events: {:?}", event_types);
181                // TODO: Store subscription preferences per connection
182            }
183            ClientMessage::Unsubscribe { event_types } => {
184                info!("Client unsubscribed from events: {:?}", event_types);
185                // TODO: Update subscription preferences
186            }
187            ClientMessage::Ping => {
188                // Client ping - we'll send a pong back via broadcast
189                self.broadcast_to_all(ServerMessage::Pong).await?;
190            }
191        }
192
193        Ok(())
194    }
195
196    /// Broadcast a message to all connected clients
197    pub async fn broadcast_to_all(&self, message: ServerMessage) -> crate::Result<()> {
198        let json_message = serde_json::to_string(&message)?;
199        let ws_message = Message::text(json_message);
200
201        let mut disconnected = Vec::new();
202
203        for (&connection_id, sender) in &self.connections {
204            if sender.send(ws_message.clone()).is_err() {
205                disconnected.push(connection_id);
206            }
207        }
208
209        // Clean up disconnected clients
210        // Note: In a real implementation, we'd need mutable access to self.connections
211        // This would be handled by the connection cleanup in handle_connection
212
213        Ok(())
214    }
215
216    /// Send ping to all connections to keep them alive
217    pub async fn ping_all_connections(&self) {
218        let ping_message = Message::ping(b"ping");
219        let mut disconnected = Vec::new();
220
221        for (&connection_id, sender) in &self.connections {
222            if sender.send(ping_message.clone()).is_err() {
223                disconnected.push(connection_id);
224            }
225        }
226
227        if !disconnected.is_empty() {
228            debug!(
229                "Detected {} disconnected WebSocket clients during ping",
230                disconnected.len()
231            );
232        }
233    }
234
235    /// Get current connection count
236    pub fn connection_count(&self) -> usize {
237        self.connections.len()
238    }
239
240    /// Start the broadcast listener task
241    pub async fn start_broadcast_listener(&mut self) -> crate::Result<()> {
242        if let Some(mut receiver) = self.broadcast_receiver.take() {
243            tokio::spawn(async move {
244                while let Some(broadcast_message) = receiver.recv().await {
245                    // Convert broadcast message to server message and send to all clients
246                    let server_message = match broadcast_message {
247                        BroadcastMessage::QueueUpdate { queue_name, stats } => {
248                            ServerMessage::QueueUpdate { queue_name, stats }
249                        }
250                        BroadcastMessage::JobUpdate { job } => ServerMessage::JobUpdate { job },
251                        BroadcastMessage::SystemAlert { message, severity } => {
252                            ServerMessage::SystemAlert { message, severity }
253                        }
254                    };
255
256                    // Note: We'd need access to the WebSocketState here to broadcast
257                    // In a real implementation, this would be handled differently
258                    debug!("Broadcasting message: {:?}", server_message);
259                }
260            });
261        }
262        Ok(())
263    }
264}
265
266/// Messages sent from client to server
267#[derive(Debug, Deserialize)]
268#[serde(tag = "type")]
269pub enum ClientMessage {
270    Subscribe { event_types: Vec<String> },
271    Unsubscribe { event_types: Vec<String> },
272    Ping,
273}
274
275/// Messages sent from server to client
276#[derive(Debug, Serialize)]
277#[serde(tag = "type")]
278pub enum ServerMessage {
279    QueueUpdate {
280        queue_name: String,
281        stats: QueueStats,
282    },
283    JobUpdate {
284        job: JobUpdate,
285    },
286    SystemAlert {
287        message: String,
288        severity: AlertSeverity,
289    },
290    Pong,
291}
292
293/// Internal broadcast messages
294#[derive(Debug)]
295pub enum BroadcastMessage {
296    QueueUpdate {
297        queue_name: String,
298        stats: QueueStats,
299    },
300    JobUpdate {
301        job: JobUpdate,
302    },
303    SystemAlert {
304        message: String,
305        severity: AlertSeverity,
306    },
307}
308
309/// Queue statistics for WebSocket updates
310#[derive(Debug, Serialize)]
311pub struct QueueStats {
312    pub pending_count: u64,
313    pub running_count: u64,
314    pub completed_count: u64,
315    pub failed_count: u64,
316    pub dead_count: u64,
317    pub throughput_per_minute: f64,
318    pub avg_processing_time_ms: f64,
319    pub error_rate: f64,
320    pub updated_at: chrono::DateTime<chrono::Utc>,
321}
322
323/// Job update information
324#[derive(Debug, Serialize)]
325pub struct JobUpdate {
326    pub id: String,
327    pub queue_name: String,
328    pub status: String,
329    pub priority: String,
330    pub attempts: i32,
331    pub updated_at: chrono::DateTime<chrono::Utc>,
332}
333
334/// Alert severity levels
335#[derive(Debug, Serialize)]
336pub enum AlertSeverity {
337    Info,
338    Warning,
339    Error,
340    Critical,
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::WebSocketConfig;

    #[test]
    fn test_websocket_state_creation() {
        let config = WebSocketConfig::default();
        let state = WebSocketState::new(config);
        assert_eq!(state.connection_count(), 0);
    }

    #[test]
    fn test_client_message_deserialization() {
        let json = r#"{"type": "Subscribe", "event_types": ["queue_updates", "job_updates"]}"#;
        let message: ClientMessage = serde_json::from_str(json).unwrap();

        match message {
            ClientMessage::Subscribe { event_types } => {
                assert_eq!(event_types.len(), 2);
                assert!(event_types.contains(&"queue_updates".to_string()));
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_unsubscribe_and_ping_deserialization() {
        let json = r#"{"type": "Unsubscribe", "event_types": ["job_updates"]}"#;
        match serde_json::from_str::<ClientMessage>(json).unwrap() {
            ClientMessage::Unsubscribe { event_types } => {
                assert_eq!(event_types, vec!["job_updates".to_string()]);
            }
            _ => panic!("Wrong message type"),
        }

        let ping: ClientMessage = serde_json::from_str(r#"{"type": "Ping"}"#).unwrap();
        assert!(matches!(ping, ClientMessage::Ping));
    }

    #[test]
    fn test_invalid_client_messages_are_rejected() {
        // Unknown variant name.
        assert!(serde_json::from_str::<ClientMessage>(r#"{"type": "Bogus"}"#).is_err());
        // No `type` tag at all.
        assert!(serde_json::from_str::<ClientMessage>(r#"{"event_types": []}"#).is_err());
        // Known variant with a required field missing.
        assert!(serde_json::from_str::<ClientMessage>(r#"{"type": "Subscribe"}"#).is_err());
    }

    #[test]
    fn test_server_message_serialization() {
        let message = ServerMessage::SystemAlert {
            message: "High error rate detected".to_string(),
            severity: AlertSeverity::Warning,
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("type"));
        assert!(json.contains("SystemAlert"));
        assert!(json.contains("High error rate detected"));
        // Pin the exact wire format the dashboard frontend parses.
        assert_eq!(
            json,
            r#"{"type":"SystemAlert","message":"High error rate detected","severity":"Warning"}"#
        );
    }

    #[test]
    fn test_pong_serialization() {
        let json = serde_json::to_string(&ServerMessage::Pong).unwrap();
        assert_eq!(json, r#"{"type":"Pong"}"#);
    }

    #[tokio::test]
    async fn test_broadcast_to_all() {
        let config = WebSocketConfig::default();
        let state = WebSocketState::new(config);

        let message = ServerMessage::Pong;
        let result = state.broadcast_to_all(message).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_client_ping_is_answered_with_pong() {
        let mut state = WebSocketState::new(WebSocketConfig::default());
        let connection_id = Uuid::new_v4();
        let (tx, mut rx) = mpsc::unbounded_channel();
        state.connections.insert(connection_id, tx);

        let result = state
            .handle_client_action(connection_id, ClientMessage::Ping)
            .await;
        assert!(result.is_ok());

        let reply = rx.try_recv().expect("a pong reply should be queued");
        assert_eq!(reply.to_str(), Ok(r#"{"type":"Pong"}"#));
    }
}