Skip to main content

hammerwork_web/
websocket.rs

1//! WebSocket implementation for real-time dashboard updates.
2//!
3//! This module provides WebSocket functionality for real-time communication between
4//! the dashboard frontend and backend. It supports connection management, message
5//! broadcasting, and automatic ping/pong for connection health.
6//!
7//! # Message Types
8//!
9//! The WebSocket API supports several message types for different events:
10//!
11//! ```rust
12//! use hammerwork_web::websocket::{ClientMessage, ServerMessage, AlertSeverity};
13//! use serde_json::json;
14//!
15//! // Client messages (sent from browser to server)
16//! let subscribe_msg = ClientMessage::Subscribe {
17//!     event_types: vec!["queue_updates".to_string(), "job_updates".to_string()],
18//! };
19//!
20//! let ping_msg = ClientMessage::Ping;
21//!
22//! // Server messages (sent from server to browser)
23//! let alert_msg = ServerMessage::SystemAlert {
24//!     message: "High error rate detected".to_string(),
25//!     severity: AlertSeverity::Warning,
26//! };
27//!
28//! let pong_msg = ServerMessage::Pong;
29//! ```
30//!
31//! # Connection Management
32//!
33//! ```rust
34//! use hammerwork_web::websocket::WebSocketState;
35//! use hammerwork_web::config::WebSocketConfig;
36//!
37//! let config = WebSocketConfig::default();
38//! let ws_state = WebSocketState::new(config);
39//!
40//! assert_eq!(ws_state.connection_count(), 0);
41//! ```
42
43use crate::config::WebSocketConfig;
44use chrono::{DateTime, Utc};
45use futures_util::{SinkExt, StreamExt};
46pub use hammerwork::archive::{ArchivalReason, ArchivalStats};
47use serde::{Deserialize, Serialize};
48use std::collections::HashMap;
49use std::sync::Arc;
50use tokio::sync::mpsc;
51use tracing::{debug, error, info, warn};
52use uuid::Uuid;
53use warp::ws::Message;
54
55/// WebSocket connection state manager
56#[derive(Debug)]
57pub struct WebSocketState {
58    config: WebSocketConfig,
59    connections: HashMap<Uuid, mpsc::UnboundedSender<Message>>,
60    subscriptions: HashMap<Uuid, std::collections::HashSet<String>>,
61    broadcast_sender: mpsc::UnboundedSender<BroadcastMessage>,
62    broadcast_receiver: Option<mpsc::UnboundedReceiver<BroadcastMessage>>,
63}
64
65impl WebSocketState {
66    pub fn new(config: WebSocketConfig) -> Self {
67        let (broadcast_sender, broadcast_receiver) = mpsc::unbounded_channel();
68
69        Self {
70            config,
71            connections: HashMap::new(),
72            subscriptions: HashMap::new(),
73            broadcast_sender,
74            broadcast_receiver: Some(broadcast_receiver),
75        }
76    }
77
78    /// Handle a new WebSocket connection
79    pub async fn handle_connection(&mut self, websocket: warp::ws::WebSocket) -> crate::Result<()> {
80        let connection_id = Uuid::new_v4();
81
82        if self.connections.len() >= self.config.max_connections {
83            warn!("Maximum WebSocket connections reached, rejecting new connection");
84            return Ok(());
85        }
86
87        let (mut ws_sender, mut ws_receiver) = websocket.split();
88        let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
89
90        // Store the connection
91        self.connections.insert(connection_id, tx);
92        info!("WebSocket connection established: {}", connection_id);
93
94        // Spawn task to handle outgoing messages to this client
95        let connection_id_clone = connection_id;
96        tokio::spawn(async move {
97            while let Some(message) = rx.recv().await {
98                if let Err(e) = ws_sender.send(message).await {
99                    debug!(
100                        "Failed to send WebSocket message to {}: {}",
101                        connection_id_clone, e
102                    );
103                    break;
104                }
105            }
106        });
107
108        // Handle incoming messages from this client
109        let broadcast_sender = self.broadcast_sender.clone();
110
111        while let Some(result) = ws_receiver.next().await {
112            match result {
113                Ok(message) => {
114                    if let Err(e) = self
115                        .handle_client_message(connection_id, message, &broadcast_sender)
116                        .await
117                    {
118                        error!(
119                            "Error handling client message from {}: {}",
120                            connection_id, e
121                        );
122                        break;
123                    }
124                }
125                Err(e) => {
126                    debug!("WebSocket error for connection {}: {}", connection_id, e);
127                    break;
128                }
129            }
130        }
131
132        // Clean up connection and subscriptions
133        self.connections.remove(&connection_id);
134        self.subscriptions.remove(&connection_id);
135        info!("WebSocket connection closed: {}", connection_id);
136
137        Ok(())
138    }
139
140    /// Handle a message from a client
141    async fn handle_client_message(
142        &mut self,
143        connection_id: Uuid,
144        message: Message,
145        _broadcast_sender: &mpsc::UnboundedSender<BroadcastMessage>,
146    ) -> crate::Result<()> {
147        if message.is_text() {
148            if let Ok(text) = message.to_str() {
149                if let Ok(client_message) = serde_json::from_str::<ClientMessage>(text) {
150                    debug!(
151                        "Received message from {}: {:?}",
152                        connection_id, client_message
153                    );
154                    self.handle_client_action(connection_id, client_message)
155                        .await?;
156                } else {
157                    warn!("Invalid message format from {}: {}", connection_id, text);
158                }
159            }
160        } else if message.is_ping() {
161            // Send pong response
162            if let Some(sender) = self.connections.get(&connection_id) {
163                let pong_msg = Message::pong(message.as_bytes());
164                let _ = sender.send(pong_msg);
165            }
166        } else if message.is_pong() {
167            // Pong received - connection is alive
168            debug!("Pong received from {}", connection_id);
169        } else if message.is_close() {
170            debug!("Close message received from {}", connection_id);
171        } else if message.is_binary() {
172            warn!("Binary message not supported from {}", connection_id);
173        }
174
175        Ok(())
176    }
177
178    /// Handle a client action
179    async fn handle_client_action(
180        &mut self,
181        connection_id: Uuid,
182        message: ClientMessage,
183    ) -> crate::Result<()> {
184        match message {
185            ClientMessage::Subscribe { event_types } => {
186                info!(
187                    "Client {} subscribed to events: {:?}",
188                    connection_id, event_types
189                );
190                // Store subscription preferences per connection
191                let subscription_set = self
192                    .subscriptions
193                    .entry(connection_id)
194                    .or_insert_with(std::collections::HashSet::new);
195                for event_type in event_types {
196                    subscription_set.insert(event_type);
197                }
198            }
199            ClientMessage::Unsubscribe { event_types } => {
200                info!(
201                    "Client {} unsubscribed from events: {:?}",
202                    connection_id, event_types
203                );
204                // Update subscription preferences
205                if let Some(subscription_set) = self.subscriptions.get_mut(&connection_id) {
206                    for event_type in event_types {
207                        subscription_set.remove(&event_type);
208                    }
209                    // Remove empty subscription sets
210                    if subscription_set.is_empty() {
211                        self.subscriptions.remove(&connection_id);
212                    }
213                }
214            }
215            ClientMessage::Ping => {
216                // Client ping - we'll send a pong back via broadcast
217                self.broadcast_to_all(ServerMessage::Pong).await?;
218            }
219        }
220
221        Ok(())
222    }
223
224    /// Broadcast a message to all connected clients
225    pub async fn broadcast_to_all(&self, message: ServerMessage) -> crate::Result<()> {
226        let json_message = serde_json::to_string(&message)?;
227        let ws_message = Message::text(json_message);
228
229        let mut disconnected = Vec::new();
230
231        for (&connection_id, sender) in &self.connections {
232            if sender.send(ws_message.clone()).is_err() {
233                disconnected.push(connection_id);
234            }
235        }
236
237        // Clean up disconnected clients
238        // Note: In a real implementation, we'd need mutable access to self.connections
239        // This would be handled by the connection cleanup in handle_connection
240
241        Ok(())
242    }
243
244    /// Broadcast a message to subscribed clients only
245    pub async fn broadcast_to_subscribed(
246        &self,
247        message: ServerMessage,
248        event_type: &str,
249    ) -> crate::Result<()> {
250        let json_message = serde_json::to_string(&message)?;
251        let ws_message = Message::text(json_message);
252
253        let mut disconnected = Vec::new();
254
255        for (&connection_id, sender) in &self.connections {
256            // Check if this connection is subscribed to this event type
257            if let Some(subscription_set) = self.subscriptions.get(&connection_id) {
258                if subscription_set.contains(event_type) {
259                    if sender.send(ws_message.clone()).is_err() {
260                        disconnected.push(connection_id);
261                    }
262                }
263            }
264        }
265
266        Ok(())
267    }
268
269    /// Publish an archive event to all connected clients
270    pub async fn publish_archive_event(
271        &self,
272        event: hammerwork::archive::ArchiveEvent,
273    ) -> crate::Result<()> {
274        let broadcast_message = match event {
275            hammerwork::archive::ArchiveEvent::JobArchived {
276                job_id,
277                queue,
278                reason,
279            } => BroadcastMessage::JobArchived {
280                job_id: job_id.to_string(),
281                queue,
282                reason,
283            },
284            hammerwork::archive::ArchiveEvent::JobRestored {
285                job_id,
286                queue,
287                restored_by,
288            } => BroadcastMessage::JobRestored {
289                job_id: job_id.to_string(),
290                queue,
291                restored_by,
292            },
293            hammerwork::archive::ArchiveEvent::BulkArchiveStarted {
294                operation_id,
295                estimated_jobs,
296            } => BroadcastMessage::BulkArchiveStarted {
297                operation_id,
298                estimated_jobs,
299            },
300            hammerwork::archive::ArchiveEvent::BulkArchiveProgress {
301                operation_id,
302                jobs_processed,
303                total,
304            } => BroadcastMessage::BulkArchiveProgress {
305                operation_id,
306                jobs_processed,
307                total,
308            },
309            hammerwork::archive::ArchiveEvent::BulkArchiveCompleted {
310                operation_id,
311                stats,
312            } => BroadcastMessage::BulkArchiveCompleted {
313                operation_id,
314                stats,
315            },
316            hammerwork::archive::ArchiveEvent::JobsPurged { count, older_than } => {
317                BroadcastMessage::JobsPurged { count, older_than }
318            }
319        };
320
321        // Send to the broadcast channel
322        if let Err(_) = self.broadcast_sender.send(broadcast_message) {
323            return Err(anyhow::anyhow!(
324                "Failed to send archive event to broadcast channel"
325            ));
326        }
327
328        Ok(())
329    }
330
331    /// Send ping to all connections to keep them alive
332    pub async fn ping_all_connections(&self) {
333        let ping_message = Message::ping(b"ping");
334        let mut disconnected = Vec::new();
335
336        for (&connection_id, sender) in &self.connections {
337            if sender.send(ping_message.clone()).is_err() {
338                disconnected.push(connection_id);
339            }
340        }
341
342        if !disconnected.is_empty() {
343            debug!(
344                "Detected {} disconnected WebSocket clients during ping",
345                disconnected.len()
346            );
347        }
348    }
349
350    /// Get current connection count
351    pub fn connection_count(&self) -> usize {
352        self.connections.len()
353    }
354
355    /// Start the broadcast listener task
356    pub async fn start_broadcast_listener(
357        state: Arc<tokio::sync::RwLock<WebSocketState>>,
358    ) -> crate::Result<()> {
359        let mut state_guard = state.write().await;
360        if let Some(mut receiver) = state_guard.broadcast_receiver.take() {
361            drop(state_guard); // Release the lock before spawning the task
362
363            tokio::spawn(async move {
364                while let Some(broadcast_message) = receiver.recv().await {
365                    // Determine the event type for subscription filtering
366                    let event_type = match &broadcast_message {
367                        BroadcastMessage::QueueUpdate { .. } => "queue_updates",
368                        BroadcastMessage::JobUpdate { .. } => "job_updates",
369                        BroadcastMessage::SystemAlert { .. } => "system_alerts",
370                        BroadcastMessage::JobArchived { .. } => "archive_events",
371                        BroadcastMessage::JobRestored { .. } => "archive_events",
372                        BroadcastMessage::BulkArchiveStarted { .. } => "archive_events",
373                        BroadcastMessage::BulkArchiveProgress { .. } => "archive_events",
374                        BroadcastMessage::BulkArchiveCompleted { .. } => "archive_events",
375                        BroadcastMessage::JobsPurged { .. } => "archive_events",
376                    };
377
378                    // Convert broadcast message to server message
379                    let server_message = match broadcast_message {
380                        BroadcastMessage::QueueUpdate { queue_name, stats } => {
381                            ServerMessage::QueueUpdate { queue_name, stats }
382                        }
383                        BroadcastMessage::JobUpdate { job } => ServerMessage::JobUpdate { job },
384                        BroadcastMessage::SystemAlert { message, severity } => {
385                            ServerMessage::SystemAlert { message, severity }
386                        }
387                        BroadcastMessage::JobArchived {
388                            job_id,
389                            queue,
390                            reason,
391                        } => ServerMessage::JobArchived {
392                            job_id,
393                            queue,
394                            reason,
395                        },
396                        BroadcastMessage::JobRestored {
397                            job_id,
398                            queue,
399                            restored_by,
400                        } => ServerMessage::JobRestored {
401                            job_id,
402                            queue,
403                            restored_by,
404                        },
405                        BroadcastMessage::BulkArchiveStarted {
406                            operation_id,
407                            estimated_jobs,
408                        } => ServerMessage::BulkArchiveStarted {
409                            operation_id,
410                            estimated_jobs,
411                        },
412                        BroadcastMessage::BulkArchiveProgress {
413                            operation_id,
414                            jobs_processed,
415                            total,
416                        } => ServerMessage::BulkArchiveProgress {
417                            operation_id,
418                            jobs_processed,
419                            total,
420                        },
421                        BroadcastMessage::BulkArchiveCompleted {
422                            operation_id,
423                            stats,
424                        } => ServerMessage::BulkArchiveCompleted {
425                            operation_id,
426                            stats,
427                        },
428                        BroadcastMessage::JobsPurged { count, older_than } => {
429                            ServerMessage::JobsPurged { count, older_than }
430                        }
431                    };
432
433                    // Actually broadcast the message to subscribed clients
434                    let state_read = state.read().await;
435                    if let Err(e) = state_read
436                        .broadcast_to_subscribed(server_message, event_type)
437                        .await
438                    {
439                        error!("Failed to broadcast message: {}", e);
440                    }
441                }
442            });
443        }
444        Ok(())
445    }
446}
447
448/// Messages sent from client to server
449#[derive(Debug, Deserialize)]
450#[serde(tag = "type")]
451pub enum ClientMessage {
452    Subscribe { event_types: Vec<String> },
453    Unsubscribe { event_types: Vec<String> },
454    Ping,
455}
456
457/// Messages sent from server to client
458#[derive(Debug, Serialize)]
459#[serde(tag = "type")]
460pub enum ServerMessage {
461    QueueUpdate {
462        queue_name: String,
463        stats: QueueStats,
464    },
465    JobUpdate {
466        job: JobUpdate,
467    },
468    SystemAlert {
469        message: String,
470        severity: AlertSeverity,
471    },
472    JobArchived {
473        job_id: String,
474        queue: String,
475        reason: ArchivalReason,
476    },
477    JobRestored {
478        job_id: String,
479        queue: String,
480        restored_by: Option<String>,
481    },
482    BulkArchiveStarted {
483        operation_id: String,
484        estimated_jobs: u64,
485    },
486    BulkArchiveProgress {
487        operation_id: String,
488        jobs_processed: u64,
489        total: u64,
490    },
491    BulkArchiveCompleted {
492        operation_id: String,
493        stats: ArchivalStats,
494    },
495    JobsPurged {
496        count: u64,
497        older_than: DateTime<Utc>,
498    },
499    Pong,
500}
501
502/// Internal broadcast messages
503#[derive(Debug)]
504pub enum BroadcastMessage {
505    QueueUpdate {
506        queue_name: String,
507        stats: QueueStats,
508    },
509    JobUpdate {
510        job: JobUpdate,
511    },
512    SystemAlert {
513        message: String,
514        severity: AlertSeverity,
515    },
516    JobArchived {
517        job_id: String,
518        queue: String,
519        reason: ArchivalReason,
520    },
521    JobRestored {
522        job_id: String,
523        queue: String,
524        restored_by: Option<String>,
525    },
526    BulkArchiveStarted {
527        operation_id: String,
528        estimated_jobs: u64,
529    },
530    BulkArchiveProgress {
531        operation_id: String,
532        jobs_processed: u64,
533        total: u64,
534    },
535    BulkArchiveCompleted {
536        operation_id: String,
537        stats: ArchivalStats,
538    },
539    JobsPurged {
540        count: u64,
541        older_than: DateTime<Utc>,
542    },
543}
544
545/// Queue statistics for WebSocket updates
546#[derive(Debug, Serialize)]
547pub struct QueueStats {
548    pub pending_count: u64,
549    pub running_count: u64,
550    pub completed_count: u64,
551    pub failed_count: u64,
552    pub dead_count: u64,
553    pub throughput_per_minute: f64,
554    pub avg_processing_time_ms: f64,
555    pub error_rate: f64,
556    pub updated_at: chrono::DateTime<chrono::Utc>,
557}
558
559/// Job update information
560#[derive(Debug, Serialize)]
561pub struct JobUpdate {
562    pub id: String,
563    pub queue_name: String,
564    pub status: String,
565    pub priority: String,
566    pub attempts: i32,
567    pub updated_at: chrono::DateTime<chrono::Utc>,
568}
569
570/// Alert severity levels
571#[derive(Debug, Serialize)]
572pub enum AlertSeverity {
573    Info,
574    Warning,
575    Error,
576    Critical,
577}
578
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::WebSocketConfig;

    /// A freshly constructed state has no registered connections.
    #[test]
    fn test_websocket_state_creation() {
        let state = WebSocketState::new(WebSocketConfig::default());
        assert_eq!(state.connection_count(), 0);
    }

    /// A tagged `Subscribe` JSON object parses into the matching variant.
    #[test]
    fn test_client_message_deserialization() {
        let json = r#"{"type": "Subscribe", "event_types": ["queue_updates", "job_updates"]}"#;
        let parsed = serde_json::from_str::<ClientMessage>(json).unwrap();

        let event_types = match parsed {
            ClientMessage::Subscribe { event_types } => event_types,
            _ => panic!("Wrong message type"),
        };

        assert_eq!(event_types.len(), 2);
        assert!(event_types.contains(&"queue_updates".to_string()));
    }

    /// Serialized server messages carry the `type` tag and their payload.
    #[test]
    fn test_server_message_serialization() {
        let alert = ServerMessage::SystemAlert {
            message: "High error rate detected".to_string(),
            severity: AlertSeverity::Warning,
        };

        let json = serde_json::to_string(&alert).unwrap();

        for expected in ["type", "SystemAlert", "High error rate detected"] {
            assert!(json.contains(expected));
        }
    }

    /// Broadcasting with no connections registered succeeds.
    #[tokio::test]
    async fn test_broadcast_to_all() {
        let state = WebSocketState::new(WebSocketConfig::default());

        let outcome = state.broadcast_to_all(ServerMessage::Pong).await;
        assert!(outcome.is_ok());
    }
}