// qml_rs/dashboard/websocket.rs
websocket.rs1use axum::{
2 extract::{
3 State,
4 ws::{Message, WebSocket, WebSocketUpgrade},
5 },
6 response::Response,
7};
8use futures_util::{SinkExt, StreamExt};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::{RwLock, broadcast};
13use uuid::Uuid;
14
15use crate::dashboard::service::{DashboardService, ServerStatistics};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18#[serde(tag = "type")]
19pub enum DashboardMessage {
20 #[serde(rename = "statistics_update")]
21 StatisticsUpdate { data: ServerStatistics },
22
23 #[serde(rename = "job_update")]
24 JobUpdate {
25 job_id: String,
26 state: String,
27 updated_at: String,
28 },
29
30 #[serde(rename = "connection_info")]
31 ConnectionInfo {
32 client_id: String,
33 connected_clients: usize,
34 },
35
36 #[serde(rename = "error")]
37 Error { message: String },
38}
39
40#[derive(Debug, Clone)]
41pub struct WebSocketConnection {
42 pub id: String,
43 pub sender: broadcast::Sender<DashboardMessage>,
44}
45
46pub struct WebSocketManager {
48 connections: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
49 broadcast_sender: broadcast::Sender<DashboardMessage>,
50 dashboard_service: Arc<DashboardService>,
51}
52
53impl WebSocketManager {
54 pub fn new(dashboard_service: Arc<DashboardService>) -> Self {
55 let (broadcast_sender, _) = broadcast::channel(1000);
56
57 Self {
58 connections: Arc::new(RwLock::new(HashMap::new())),
59 broadcast_sender,
60 dashboard_service,
61 }
62 }
63
64 pub async fn handle_websocket(&self, ws: WebSocketUpgrade) -> Response {
66 let connections = Arc::clone(&self.connections);
67 let broadcast_sender = self.broadcast_sender.clone();
68 let dashboard_service = Arc::clone(&self.dashboard_service);
69
70 ws.on_upgrade(move |socket| {
71 Self::handle_socket(socket, connections, broadcast_sender, dashboard_service)
72 })
73 }
74
75 async fn handle_socket(
77 socket: WebSocket,
78 connections: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
79 broadcast_sender: broadcast::Sender<DashboardMessage>,
80 dashboard_service: Arc<DashboardService>,
81 ) {
82 let client_id = Uuid::new_v4().to_string();
83 let mut receiver = broadcast_sender.subscribe();
84
85 let (mut sender, mut socket_receiver) = socket.split();
87
88 {
90 let mut conns = connections.write().await;
91 conns.insert(
92 client_id.clone(),
93 WebSocketConnection {
94 id: client_id.clone(),
95 sender: broadcast_sender.clone(),
96 },
97 );
98 }
99
100 let connected_count = connections.read().await.len();
102 let connection_msg = DashboardMessage::ConnectionInfo {
103 client_id: client_id.clone(),
104 connected_clients: connected_count,
105 };
106
107 if let Ok(msg_str) = serde_json::to_string(&connection_msg) {
108 let _ = sender.send(Message::Text(msg_str.into())).await;
109 }
110
111 if let Ok(stats) = dashboard_service.get_server_statistics().await {
113 let stats_msg = DashboardMessage::StatisticsUpdate { data: stats };
114 if let Ok(msg_str) = serde_json::to_string(&stats_msg) {
115 let _ = sender.send(Message::Text(msg_str.into())).await;
116 }
117 }
118
119 let connections_clone = Arc::clone(&connections);
121 let client_id_clone = client_id.clone();
122
123 tokio::select! {
124 _ = async {
126 while let Some(msg) = socket_receiver.next().await {
127 match msg {
128 Ok(Message::Text(text)) => {
129 tracing::debug!("Received WebSocket message: {}", text);
130 }
132 Ok(Message::Close(_)) => {
133 tracing::info!("WebSocket connection closed by client: {}", client_id);
134 break;
135 }
136 Err(e) => {
137 tracing::error!("WebSocket error: {}", e);
138 break;
139 }
140 _ => {}
141 }
142 }
143 } => {}
144
145 _ = async {
147 while let Ok(msg) = receiver.recv().await {
148 if let Ok(msg_str) = serde_json::to_string(&msg) {
149 if sender.send(Message::Text(msg_str.into())).await.is_err() {
150 tracing::info!("Failed to send message to client {}, removing connection", client_id);
151 break;
152 }
153 }
154 }
155 } => {}
156 }
157
158 {
160 let mut conns = connections_clone.write().await;
161 conns.remove(&client_id_clone);
162 }
163
164 tracing::info!("WebSocket connection {} disconnected", client_id_clone);
165 }
166
167 pub async fn broadcast_statistics_update(
169 &self,
170 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
171 let stats = self.dashboard_service.get_server_statistics().await?;
172 let msg = DashboardMessage::StatisticsUpdate { data: stats };
173
174 let _ = self.broadcast_sender.send(msg);
175 Ok(())
176 }
177
178 pub async fn broadcast_job_update(&self, job_id: String, state: String, updated_at: String) {
180 let msg = DashboardMessage::JobUpdate {
181 job_id,
182 state,
183 updated_at,
184 };
185
186 let _ = self.broadcast_sender.send(msg);
187 }
188
189 pub async fn connected_clients_count(&self) -> usize {
191 self.connections.read().await.len()
192 }
193
194 pub async fn start_periodic_updates(&self, interval_seconds: u64) {
196 let broadcast_sender = self.broadcast_sender.clone();
197 let dashboard_service = Arc::clone(&self.dashboard_service);
198
199 tokio::spawn(async move {
200 let mut interval =
201 tokio::time::interval(tokio::time::Duration::from_secs(interval_seconds));
202
203 loop {
204 interval.tick().await;
205
206 if let Ok(stats) = dashboard_service.get_server_statistics().await {
207 let msg = DashboardMessage::StatisticsUpdate { data: stats };
208 let _ = broadcast_sender.send(msg);
209 } else {
210 tracing::error!("Failed to get statistics for periodic update");
211 }
212 }
213 });
214 }
215}
216
217pub async fn websocket_handler(
219 ws: WebSocketUpgrade,
220 State(ws_manager): State<Arc<WebSocketManager>>,
221) -> Response {
222 ws_manager.handle_websocket(ws).await
223}