qml_rs/dashboard/
websocket.rs

1use axum::{
2    extract::{
3        State,
4        ws::{Message, WebSocket, WebSocketUpgrade},
5    },
6    response::Response,
7};
8use futures_util::{SinkExt, StreamExt};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::{RwLock, broadcast};
13use uuid::Uuid;
14
15use crate::dashboard::service::{DashboardService, ServerStatistics};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18#[serde(tag = "type")]
19pub enum DashboardMessage {
20    #[serde(rename = "statistics_update")]
21    StatisticsUpdate { data: ServerStatistics },
22
23    #[serde(rename = "job_update")]
24    JobUpdate {
25        job_id: String,
26        state: String,
27        updated_at: String,
28    },
29
30    #[serde(rename = "connection_info")]
31    ConnectionInfo {
32        client_id: String,
33        connected_clients: usize,
34    },
35
36    #[serde(rename = "error")]
37    Error { message: String },
38}
39
40#[derive(Debug, Clone)]
41pub struct WebSocketConnection {
42    pub id: String,
43    pub sender: broadcast::Sender<DashboardMessage>,
44}
45
/// WebSocket manager for handling real-time dashboard connections.
///
/// Every connected client subscribes to one shared broadcast channel, so a
/// single `send` on `broadcast_sender` reaches all clients.
pub struct WebSocketManager {
    /// Currently connected clients, keyed by client id. Shared with each
    /// connection task so it can deregister itself on disconnect.
    connections: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
    /// Sending half of the channel every client task subscribes to.
    broadcast_sender: broadcast::Sender<DashboardMessage>,
    /// Source of the statistics pushed to clients.
    dashboard_service: Arc<DashboardService>,
}
52
53impl WebSocketManager {
54    pub fn new(dashboard_service: Arc<DashboardService>) -> Self {
55        let (broadcast_sender, _) = broadcast::channel(1000);
56
57        Self {
58            connections: Arc::new(RwLock::new(HashMap::new())),
59            broadcast_sender,
60            dashboard_service,
61        }
62    }
63
64    /// Handle new WebSocket connection
65    pub async fn handle_websocket(&self, ws: WebSocketUpgrade) -> Response {
66        let connections = Arc::clone(&self.connections);
67        let broadcast_sender = self.broadcast_sender.clone();
68        let dashboard_service = Arc::clone(&self.dashboard_service);
69
70        ws.on_upgrade(move |socket| {
71            Self::handle_socket(socket, connections, broadcast_sender, dashboard_service)
72        })
73    }
74
75    /// Handle individual WebSocket connection
76    async fn handle_socket(
77        socket: WebSocket,
78        connections: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
79        broadcast_sender: broadcast::Sender<DashboardMessage>,
80        dashboard_service: Arc<DashboardService>,
81    ) {
82        let client_id = Uuid::new_v4().to_string();
83        let mut receiver = broadcast_sender.subscribe();
84
85        // Split the socket into sender and receiver
86        let (mut sender, mut socket_receiver) = socket.split();
87
88        // Add connection to the manager
89        {
90            let mut conns = connections.write().await;
91            conns.insert(
92                client_id.clone(),
93                WebSocketConnection {
94                    id: client_id.clone(),
95                    sender: broadcast_sender.clone(),
96                },
97            );
98        }
99
100        // Send initial connection info
101        let connected_count = connections.read().await.len();
102        let connection_msg = DashboardMessage::ConnectionInfo {
103            client_id: client_id.clone(),
104            connected_clients: connected_count,
105        };
106
107        if let Ok(msg_str) = serde_json::to_string(&connection_msg) {
108            let _ = sender.send(Message::Text(msg_str.into())).await;
109        }
110
111        // Send initial statistics
112        if let Ok(stats) = dashboard_service.get_server_statistics().await {
113            let stats_msg = DashboardMessage::StatisticsUpdate { data: stats };
114            if let Ok(msg_str) = serde_json::to_string(&stats_msg) {
115                let _ = sender.send(Message::Text(msg_str.into())).await;
116            }
117        }
118
119        // Handle incoming messages and broadcast updates
120        let connections_clone = Arc::clone(&connections);
121        let client_id_clone = client_id.clone();
122
123        tokio::select! {
124            // Handle incoming WebSocket messages
125            _ = async {
126                while let Some(msg) = socket_receiver.next().await {
127                    match msg {
128                        Ok(Message::Text(text)) => {
129                            tracing::debug!("Received WebSocket message: {}", text);
130                            // Handle client messages if needed (ping/pong, etc.)
131                        }
132                        Ok(Message::Close(_)) => {
133                            tracing::info!("WebSocket connection closed by client: {}", client_id);
134                            break;
135                        }
136                        Err(e) => {
137                            tracing::error!("WebSocket error: {}", e);
138                            break;
139                        }
140                        _ => {}
141                    }
142                }
143            } => {}
144
145            // Handle broadcast messages
146            _ = async {
147                while let Ok(msg) = receiver.recv().await {
148                    if let Ok(msg_str) = serde_json::to_string(&msg) {
149                        if sender.send(Message::Text(msg_str.into())).await.is_err() {
150                            tracing::info!("Failed to send message to client {}, removing connection", client_id);
151                            break;
152                        }
153                    }
154                }
155            } => {}
156        }
157
158        // Remove connection when done
159        {
160            let mut conns = connections_clone.write().await;
161            conns.remove(&client_id_clone);
162        }
163
164        tracing::info!("WebSocket connection {} disconnected", client_id_clone);
165    }
166
167    /// Broadcast statistics update to all connected clients
168    pub async fn broadcast_statistics_update(
169        &self,
170    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
171        let stats = self.dashboard_service.get_server_statistics().await?;
172        let msg = DashboardMessage::StatisticsUpdate { data: stats };
173
174        let _ = self.broadcast_sender.send(msg);
175        Ok(())
176    }
177
178    /// Broadcast job update to all connected clients
179    pub async fn broadcast_job_update(&self, job_id: String, state: String, updated_at: String) {
180        let msg = DashboardMessage::JobUpdate {
181            job_id,
182            state,
183            updated_at,
184        };
185
186        let _ = self.broadcast_sender.send(msg);
187    }
188
189    /// Get number of connected clients
190    pub async fn connected_clients_count(&self) -> usize {
191        self.connections.read().await.len()
192    }
193
194    /// Start periodic statistics broadcast
195    pub async fn start_periodic_updates(&self, interval_seconds: u64) {
196        let broadcast_sender = self.broadcast_sender.clone();
197        let dashboard_service = Arc::clone(&self.dashboard_service);
198
199        tokio::spawn(async move {
200            let mut interval =
201                tokio::time::interval(tokio::time::Duration::from_secs(interval_seconds));
202
203            loop {
204                interval.tick().await;
205
206                if let Ok(stats) = dashboard_service.get_server_statistics().await {
207                    let msg = DashboardMessage::StatisticsUpdate { data: stats };
208                    let _ = broadcast_sender.send(msg);
209                } else {
210                    tracing::error!("Failed to get statistics for periodic update");
211                }
212            }
213        });
214    }
215}
216
217/// Create WebSocket handler function for axum router
218pub async fn websocket_handler(
219    ws: WebSocketUpgrade,
220    State(ws_manager): State<Arc<WebSocketManager>>,
221) -> Response {
222    ws_manager.handle_websocket(ws).await
223}