1use crate::config::WebSocketConfig;
44use chrono::{DateTime, Utc};
45use futures_util::{SinkExt, StreamExt};
46pub use hammerwork::archive::{ArchivalReason, ArchivalStats};
47use serde::{Deserialize, Serialize};
48use std::collections::HashMap;
49use std::sync::Arc;
50use tokio::sync::mpsc;
51use tracing::{debug, error, info, warn};
52use uuid::Uuid;
53use warp::ws::Message;
54
55#[derive(Debug)]
57pub struct WebSocketState {
58 config: WebSocketConfig,
59 connections: HashMap<Uuid, mpsc::UnboundedSender<Message>>,
60 subscriptions: HashMap<Uuid, std::collections::HashSet<String>>,
61 broadcast_sender: mpsc::UnboundedSender<BroadcastMessage>,
62 broadcast_receiver: Option<mpsc::UnboundedReceiver<BroadcastMessage>>,
63}
64
65impl WebSocketState {
66 pub fn new(config: WebSocketConfig) -> Self {
67 let (broadcast_sender, broadcast_receiver) = mpsc::unbounded_channel();
68
69 Self {
70 config,
71 connections: HashMap::new(),
72 subscriptions: HashMap::new(),
73 broadcast_sender,
74 broadcast_receiver: Some(broadcast_receiver),
75 }
76 }
77
78 pub async fn handle_connection(&mut self, websocket: warp::ws::WebSocket) -> crate::Result<()> {
80 let connection_id = Uuid::new_v4();
81
82 if self.connections.len() >= self.config.max_connections {
83 warn!("Maximum WebSocket connections reached, rejecting new connection");
84 return Ok(());
85 }
86
87 let (mut ws_sender, mut ws_receiver) = websocket.split();
88 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
89
90 self.connections.insert(connection_id, tx);
92 info!("WebSocket connection established: {}", connection_id);
93
94 let connection_id_clone = connection_id;
96 tokio::spawn(async move {
97 while let Some(message) = rx.recv().await {
98 if let Err(e) = ws_sender.send(message).await {
99 debug!(
100 "Failed to send WebSocket message to {}: {}",
101 connection_id_clone, e
102 );
103 break;
104 }
105 }
106 });
107
108 let broadcast_sender = self.broadcast_sender.clone();
110
111 while let Some(result) = ws_receiver.next().await {
112 match result {
113 Ok(message) => {
114 if let Err(e) = self
115 .handle_client_message(connection_id, message, &broadcast_sender)
116 .await
117 {
118 error!(
119 "Error handling client message from {}: {}",
120 connection_id, e
121 );
122 break;
123 }
124 }
125 Err(e) => {
126 debug!("WebSocket error for connection {}: {}", connection_id, e);
127 break;
128 }
129 }
130 }
131
132 self.connections.remove(&connection_id);
134 self.subscriptions.remove(&connection_id);
135 info!("WebSocket connection closed: {}", connection_id);
136
137 Ok(())
138 }
139
140 async fn handle_client_message(
142 &mut self,
143 connection_id: Uuid,
144 message: Message,
145 _broadcast_sender: &mpsc::UnboundedSender<BroadcastMessage>,
146 ) -> crate::Result<()> {
147 if message.is_text() {
148 if let Ok(text) = message.to_str() {
149 if let Ok(client_message) = serde_json::from_str::<ClientMessage>(text) {
150 debug!(
151 "Received message from {}: {:?}",
152 connection_id, client_message
153 );
154 self.handle_client_action(connection_id, client_message)
155 .await?;
156 } else {
157 warn!("Invalid message format from {}: {}", connection_id, text);
158 }
159 }
160 } else if message.is_ping() {
161 if let Some(sender) = self.connections.get(&connection_id) {
163 let pong_msg = Message::pong(message.as_bytes());
164 let _ = sender.send(pong_msg);
165 }
166 } else if message.is_pong() {
167 debug!("Pong received from {}", connection_id);
169 } else if message.is_close() {
170 debug!("Close message received from {}", connection_id);
171 } else if message.is_binary() {
172 warn!("Binary message not supported from {}", connection_id);
173 }
174
175 Ok(())
176 }
177
178 async fn handle_client_action(
180 &mut self,
181 connection_id: Uuid,
182 message: ClientMessage,
183 ) -> crate::Result<()> {
184 match message {
185 ClientMessage::Subscribe { event_types } => {
186 info!(
187 "Client {} subscribed to events: {:?}",
188 connection_id, event_types
189 );
190 let subscription_set = self
192 .subscriptions
193 .entry(connection_id)
194 .or_insert_with(std::collections::HashSet::new);
195 for event_type in event_types {
196 subscription_set.insert(event_type);
197 }
198 }
199 ClientMessage::Unsubscribe { event_types } => {
200 info!(
201 "Client {} unsubscribed from events: {:?}",
202 connection_id, event_types
203 );
204 if let Some(subscription_set) = self.subscriptions.get_mut(&connection_id) {
206 for event_type in event_types {
207 subscription_set.remove(&event_type);
208 }
209 if subscription_set.is_empty() {
211 self.subscriptions.remove(&connection_id);
212 }
213 }
214 }
215 ClientMessage::Ping => {
216 self.broadcast_to_all(ServerMessage::Pong).await?;
218 }
219 }
220
221 Ok(())
222 }
223
224 pub async fn broadcast_to_all(&self, message: ServerMessage) -> crate::Result<()> {
226 let json_message = serde_json::to_string(&message)?;
227 let ws_message = Message::text(json_message);
228
229 let mut disconnected = Vec::new();
230
231 for (&connection_id, sender) in &self.connections {
232 if sender.send(ws_message.clone()).is_err() {
233 disconnected.push(connection_id);
234 }
235 }
236
237 Ok(())
242 }
243
244 pub async fn broadcast_to_subscribed(
246 &self,
247 message: ServerMessage,
248 event_type: &str,
249 ) -> crate::Result<()> {
250 let json_message = serde_json::to_string(&message)?;
251 let ws_message = Message::text(json_message);
252
253 let mut disconnected = Vec::new();
254
255 for (&connection_id, sender) in &self.connections {
256 if let Some(subscription_set) = self.subscriptions.get(&connection_id) {
258 if subscription_set.contains(event_type) {
259 if sender.send(ws_message.clone()).is_err() {
260 disconnected.push(connection_id);
261 }
262 }
263 }
264 }
265
266 Ok(())
267 }
268
269 pub async fn publish_archive_event(
271 &self,
272 event: hammerwork::archive::ArchiveEvent,
273 ) -> crate::Result<()> {
274 let broadcast_message = match event {
275 hammerwork::archive::ArchiveEvent::JobArchived {
276 job_id,
277 queue,
278 reason,
279 } => BroadcastMessage::JobArchived {
280 job_id: job_id.to_string(),
281 queue,
282 reason,
283 },
284 hammerwork::archive::ArchiveEvent::JobRestored {
285 job_id,
286 queue,
287 restored_by,
288 } => BroadcastMessage::JobRestored {
289 job_id: job_id.to_string(),
290 queue,
291 restored_by,
292 },
293 hammerwork::archive::ArchiveEvent::BulkArchiveStarted {
294 operation_id,
295 estimated_jobs,
296 } => BroadcastMessage::BulkArchiveStarted {
297 operation_id,
298 estimated_jobs,
299 },
300 hammerwork::archive::ArchiveEvent::BulkArchiveProgress {
301 operation_id,
302 jobs_processed,
303 total,
304 } => BroadcastMessage::BulkArchiveProgress {
305 operation_id,
306 jobs_processed,
307 total,
308 },
309 hammerwork::archive::ArchiveEvent::BulkArchiveCompleted {
310 operation_id,
311 stats,
312 } => BroadcastMessage::BulkArchiveCompleted {
313 operation_id,
314 stats,
315 },
316 hammerwork::archive::ArchiveEvent::JobsPurged { count, older_than } => {
317 BroadcastMessage::JobsPurged { count, older_than }
318 }
319 };
320
321 if let Err(_) = self.broadcast_sender.send(broadcast_message) {
323 return Err(anyhow::anyhow!(
324 "Failed to send archive event to broadcast channel"
325 ));
326 }
327
328 Ok(())
329 }
330
331 pub async fn ping_all_connections(&self) {
333 let ping_message = Message::ping(b"ping");
334 let mut disconnected = Vec::new();
335
336 for (&connection_id, sender) in &self.connections {
337 if sender.send(ping_message.clone()).is_err() {
338 disconnected.push(connection_id);
339 }
340 }
341
342 if !disconnected.is_empty() {
343 debug!(
344 "Detected {} disconnected WebSocket clients during ping",
345 disconnected.len()
346 );
347 }
348 }
349
350 pub fn connection_count(&self) -> usize {
352 self.connections.len()
353 }
354
355 pub async fn start_broadcast_listener(
357 state: Arc<tokio::sync::RwLock<WebSocketState>>,
358 ) -> crate::Result<()> {
359 let mut state_guard = state.write().await;
360 if let Some(mut receiver) = state_guard.broadcast_receiver.take() {
361 drop(state_guard); tokio::spawn(async move {
364 while let Some(broadcast_message) = receiver.recv().await {
365 let event_type = match &broadcast_message {
367 BroadcastMessage::QueueUpdate { .. } => "queue_updates",
368 BroadcastMessage::JobUpdate { .. } => "job_updates",
369 BroadcastMessage::SystemAlert { .. } => "system_alerts",
370 BroadcastMessage::JobArchived { .. } => "archive_events",
371 BroadcastMessage::JobRestored { .. } => "archive_events",
372 BroadcastMessage::BulkArchiveStarted { .. } => "archive_events",
373 BroadcastMessage::BulkArchiveProgress { .. } => "archive_events",
374 BroadcastMessage::BulkArchiveCompleted { .. } => "archive_events",
375 BroadcastMessage::JobsPurged { .. } => "archive_events",
376 };
377
378 let server_message = match broadcast_message {
380 BroadcastMessage::QueueUpdate { queue_name, stats } => {
381 ServerMessage::QueueUpdate { queue_name, stats }
382 }
383 BroadcastMessage::JobUpdate { job } => ServerMessage::JobUpdate { job },
384 BroadcastMessage::SystemAlert { message, severity } => {
385 ServerMessage::SystemAlert { message, severity }
386 }
387 BroadcastMessage::JobArchived {
388 job_id,
389 queue,
390 reason,
391 } => ServerMessage::JobArchived {
392 job_id,
393 queue,
394 reason,
395 },
396 BroadcastMessage::JobRestored {
397 job_id,
398 queue,
399 restored_by,
400 } => ServerMessage::JobRestored {
401 job_id,
402 queue,
403 restored_by,
404 },
405 BroadcastMessage::BulkArchiveStarted {
406 operation_id,
407 estimated_jobs,
408 } => ServerMessage::BulkArchiveStarted {
409 operation_id,
410 estimated_jobs,
411 },
412 BroadcastMessage::BulkArchiveProgress {
413 operation_id,
414 jobs_processed,
415 total,
416 } => ServerMessage::BulkArchiveProgress {
417 operation_id,
418 jobs_processed,
419 total,
420 },
421 BroadcastMessage::BulkArchiveCompleted {
422 operation_id,
423 stats,
424 } => ServerMessage::BulkArchiveCompleted {
425 operation_id,
426 stats,
427 },
428 BroadcastMessage::JobsPurged { count, older_than } => {
429 ServerMessage::JobsPurged { count, older_than }
430 }
431 };
432
433 let state_read = state.read().await;
435 if let Err(e) = state_read
436 .broadcast_to_subscribed(server_message, event_type)
437 .await
438 {
439 error!("Failed to broadcast message: {}", e);
440 }
441 }
442 });
443 }
444 Ok(())
445 }
446}
447
448#[derive(Debug, Deserialize)]
450#[serde(tag = "type")]
451pub enum ClientMessage {
452 Subscribe { event_types: Vec<String> },
453 Unsubscribe { event_types: Vec<String> },
454 Ping,
455}
456
457#[derive(Debug, Serialize)]
459#[serde(tag = "type")]
460pub enum ServerMessage {
461 QueueUpdate {
462 queue_name: String,
463 stats: QueueStats,
464 },
465 JobUpdate {
466 job: JobUpdate,
467 },
468 SystemAlert {
469 message: String,
470 severity: AlertSeverity,
471 },
472 JobArchived {
473 job_id: String,
474 queue: String,
475 reason: ArchivalReason,
476 },
477 JobRestored {
478 job_id: String,
479 queue: String,
480 restored_by: Option<String>,
481 },
482 BulkArchiveStarted {
483 operation_id: String,
484 estimated_jobs: u64,
485 },
486 BulkArchiveProgress {
487 operation_id: String,
488 jobs_processed: u64,
489 total: u64,
490 },
491 BulkArchiveCompleted {
492 operation_id: String,
493 stats: ArchivalStats,
494 },
495 JobsPurged {
496 count: u64,
497 older_than: DateTime<Utc>,
498 },
499 Pong,
500}
501
502#[derive(Debug)]
504pub enum BroadcastMessage {
505 QueueUpdate {
506 queue_name: String,
507 stats: QueueStats,
508 },
509 JobUpdate {
510 job: JobUpdate,
511 },
512 SystemAlert {
513 message: String,
514 severity: AlertSeverity,
515 },
516 JobArchived {
517 job_id: String,
518 queue: String,
519 reason: ArchivalReason,
520 },
521 JobRestored {
522 job_id: String,
523 queue: String,
524 restored_by: Option<String>,
525 },
526 BulkArchiveStarted {
527 operation_id: String,
528 estimated_jobs: u64,
529 },
530 BulkArchiveProgress {
531 operation_id: String,
532 jobs_processed: u64,
533 total: u64,
534 },
535 BulkArchiveCompleted {
536 operation_id: String,
537 stats: ArchivalStats,
538 },
539 JobsPurged {
540 count: u64,
541 older_than: DateTime<Utc>,
542 },
543}
544
545#[derive(Debug, Serialize)]
547pub struct QueueStats {
548 pub pending_count: u64,
549 pub running_count: u64,
550 pub completed_count: u64,
551 pub failed_count: u64,
552 pub dead_count: u64,
553 pub throughput_per_minute: f64,
554 pub avg_processing_time_ms: f64,
555 pub error_rate: f64,
556 pub updated_at: chrono::DateTime<chrono::Utc>,
557}
558
559#[derive(Debug, Serialize)]
561pub struct JobUpdate {
562 pub id: String,
563 pub queue_name: String,
564 pub status: String,
565 pub priority: String,
566 pub attempts: i32,
567 pub updated_at: chrono::DateTime<chrono::Utc>,
568}
569
570#[derive(Debug, Serialize)]
572pub enum AlertSeverity {
573 Info,
574 Warning,
575 Error,
576 Critical,
577}
578
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::WebSocketConfig;
    use chrono::Utc;
    use std::collections::HashSet;
    use tokio::sync::mpsc;
    use uuid::Uuid;

    /// Registers a fake connection on `state`. Returns its id and the receiving end of
    /// its outbound queue.
    fn attach_connection(state: &mut WebSocketState) -> (Uuid, mpsc::UnboundedReceiver<Message>) {
        let connection_id = Uuid::new_v4();
        let (tx, rx) = mpsc::unbounded_channel();
        state.connections.insert(connection_id, tx);
        (connection_id, rx)
    }

    /// Builds an owned list of topic names.
    fn topics(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_websocket_state_creation() {
        let config = WebSocketConfig::default();
        let state = WebSocketState::new(config);
        assert_eq!(state.connection_count(), 0);
    }

    #[test]
    fn test_client_message_deserialization() {
        let json = r#"{"type": "Subscribe", "event_types": ["queue_updates", "job_updates"]}"#;
        let message: ClientMessage = serde_json::from_str(json).unwrap();

        match message {
            ClientMessage::Subscribe { event_types } => {
                assert_eq!(event_types.len(), 2);
                assert!(event_types.contains(&"queue_updates".to_string()));
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_ping_message_deserialization() {
        let message: ClientMessage = serde_json::from_str(r#"{"type": "Ping"}"#).unwrap();
        assert!(matches!(message, ClientMessage::Ping));
    }

    #[test]
    fn test_server_message_serialization() {
        let message = ServerMessage::SystemAlert {
            message: "High error rate detected".to_string(),
            severity: AlertSeverity::Warning,
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("type"));
        assert!(json.contains("SystemAlert"));
        assert!(json.contains("High error rate detected"));
    }

    #[tokio::test]
    async fn test_broadcast_to_all() {
        let config = WebSocketConfig::default();
        let state = WebSocketState::new(config);

        let message = ServerMessage::Pong;
        let result = state.broadcast_to_all(message).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_unsubscribing_from_every_topic_drops_the_entry() {
        let mut state = WebSocketState::new(WebSocketConfig::default());
        let connection_id = Uuid::new_v4();

        let subscribe = ClientMessage::Subscribe {
            event_types: topics(&["queue_updates", "job_updates"]),
        };
        assert!(state.handle_client_action(connection_id, subscribe).await.is_ok());
        assert_eq!(
            state.subscriptions.get(&connection_id).map(|set| set.len()),
            Some(2)
        );

        let partial = ClientMessage::Unsubscribe {
            event_types: topics(&["queue_updates"]),
        };
        assert!(state.handle_client_action(connection_id, partial).await.is_ok());
        assert_eq!(
            state.subscriptions.get(&connection_id).map(|set| set.len()),
            Some(1)
        );

        let rest = ClientMessage::Unsubscribe {
            event_types: topics(&["job_updates"]),
        };
        assert!(state.handle_client_action(connection_id, rest).await.is_ok());
        assert!(!state.subscriptions.contains_key(&connection_id));
    }

    #[tokio::test]
    async fn test_broadcast_to_subscribed_reaches_only_subscribers() {
        let mut state = WebSocketState::new(WebSocketConfig::default());
        let (subscriber, mut subscriber_rx) = attach_connection(&mut state);
        let (_bystander, mut bystander_rx) = attach_connection(&mut state);
        state.subscriptions.insert(
            subscriber,
            topics(&["system_alerts"]).into_iter().collect::<HashSet<_>>(),
        );

        let message = ServerMessage::SystemAlert {
            message: "disk almost full".to_string(),
            severity: AlertSeverity::Critical,
        };
        let result = state.broadcast_to_subscribed(message, "system_alerts").await;
        assert!(result.is_ok());

        let delivered = subscriber_rx
            .try_recv()
            .expect("subscriber should receive the alert");
        assert!(delivered
            .to_str()
            .expect("alert should be a text frame")
            .contains("disk almost full"));
        assert!(bystander_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_publish_archive_event_forwards_to_broadcast_channel() {
        let mut state = WebSocketState::new(WebSocketConfig::default());
        let event = hammerwork::archive::ArchiveEvent::JobsPurged {
            count: 7,
            older_than: Utc::now(),
        };
        assert!(state.publish_archive_event(event).await.is_ok());

        let mut receiver = state
            .broadcast_receiver
            .take()
            .expect("receiver has not been taken yet");
        match receiver.try_recv() {
            Ok(BroadcastMessage::JobsPurged { count, .. }) => assert_eq!(count, 7),
            other => panic!("unexpected broadcast: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_publish_archive_event_fails_without_receiver() {
        let mut state = WebSocketState::new(WebSocketConfig::default());
        drop(state.broadcast_receiver.take());

        let event = hammerwork::archive::ArchiveEvent::JobsPurged {
            count: 1,
            older_than: Utc::now(),
        };
        assert!(state.publish_archive_event(event).await.is_err());
    }
}