use axum::{
extract::{
State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
response::Response,
};
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, broadcast};
use uuid::Uuid;
use crate::dashboard::service::{DashboardService, ServerStatistics};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum DashboardMessage {
#[serde(rename = "statistics_update")]
StatisticsUpdate { data: ServerStatistics },
#[serde(rename = "job_update")]
JobUpdate {
job_id: String,
state: String,
updated_at: String,
},
#[serde(rename = "connection_info")]
ConnectionInfo {
client_id: String,
connected_clients: usize,
},
#[serde(rename = "error")]
Error { message: String },
}
#[derive(Debug, Clone)]
pub struct WebSocketConnection {
pub id: String,
pub sender: broadcast::Sender<DashboardMessage>,
}
pub struct WebSocketManager {
connections: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
broadcast_sender: broadcast::Sender<DashboardMessage>,
dashboard_service: Arc<DashboardService>,
}
impl WebSocketManager {
pub fn new(dashboard_service: Arc<DashboardService>) -> Self {
let (broadcast_sender, _) = broadcast::channel(1000);
Self {
connections: Arc::new(RwLock::new(HashMap::new())),
broadcast_sender,
dashboard_service,
}
}
pub async fn handle_websocket(&self, ws: WebSocketUpgrade) -> Response {
let connections = Arc::clone(&self.connections);
let broadcast_sender = self.broadcast_sender.clone();
let dashboard_service = Arc::clone(&self.dashboard_service);
ws.on_upgrade(move |socket| {
Self::handle_socket(socket, connections, broadcast_sender, dashboard_service)
})
}
async fn handle_socket(
socket: WebSocket,
connections: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
broadcast_sender: broadcast::Sender<DashboardMessage>,
dashboard_service: Arc<DashboardService>,
) {
let client_id = Uuid::new_v4().to_string();
let mut receiver = broadcast_sender.subscribe();
let (mut sender, mut socket_receiver) = socket.split();
{
let mut conns = connections.write().await;
conns.insert(
client_id.clone(),
WebSocketConnection {
id: client_id.clone(),
sender: broadcast_sender.clone(),
},
);
}
let connected_count = connections.read().await.len();
let connection_msg = DashboardMessage::ConnectionInfo {
client_id: client_id.clone(),
connected_clients: connected_count,
};
if let Ok(msg_str) = serde_json::to_string(&connection_msg) {
let _ = sender.send(Message::Text(msg_str.into())).await;
}
if let Ok(stats) = dashboard_service.get_server_statistics().await {
let stats_msg = DashboardMessage::StatisticsUpdate { data: stats };
if let Ok(msg_str) = serde_json::to_string(&stats_msg) {
let _ = sender.send(Message::Text(msg_str.into())).await;
}
}
let connections_clone = Arc::clone(&connections);
let client_id_clone = client_id.clone();
tokio::select! {
_ = async {
while let Some(msg) = socket_receiver.next().await {
match msg {
Ok(Message::Text(text)) => {
tracing::debug!("Received WebSocket message: {}", text);
}
Ok(Message::Close(_)) => {
tracing::info!("WebSocket connection closed by client: {}", client_id);
break;
}
Err(e) => {
tracing::error!("WebSocket error: {}", e);
break;
}
_ => {}
}
}
} => {}
_ = async {
while let Ok(msg) = receiver.recv().await {
if let Ok(msg_str) = serde_json::to_string(&msg)
&& sender.send(Message::Text(msg_str.into())).await.is_err()
{
tracing::info!("Failed to send message to client {}, removing connection", client_id);
break;
}
}
} => {}
}
{
let mut conns = connections_clone.write().await;
conns.remove(&client_id_clone);
}
tracing::info!("WebSocket connection {} disconnected", client_id_clone);
}
pub async fn broadcast_statistics_update(
&self,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let stats = self.dashboard_service.get_server_statistics().await?;
let msg = DashboardMessage::StatisticsUpdate { data: stats };
let _ = self.broadcast_sender.send(msg);
Ok(())
}
pub async fn broadcast_job_update(&self, job_id: String, state: String, updated_at: String) {
let msg = DashboardMessage::JobUpdate {
job_id,
state,
updated_at,
};
let _ = self.broadcast_sender.send(msg);
}
pub async fn connected_clients_count(&self) -> usize {
self.connections.read().await.len()
}
pub async fn start_periodic_updates(&self, interval_seconds: u64) {
let broadcast_sender = self.broadcast_sender.clone();
let dashboard_service = Arc::clone(&self.dashboard_service);
tokio::spawn(async move {
let mut interval =
tokio::time::interval(tokio::time::Duration::from_secs(interval_seconds));
loop {
interval.tick().await;
if let Ok(stats) = dashboard_service.get_server_statistics().await {
let msg = DashboardMessage::StatisticsUpdate { data: stats };
let _ = broadcast_sender.send(msg);
} else {
tracing::error!("Failed to get statistics for periodic update");
}
}
});
}
}
pub async fn websocket_handler(
ws: WebSocketUpgrade,
State(ws_manager): State<Arc<WebSocketManager>>,
) -> Response {
ws_manager.handle_websocket(ws).await
}