Skip to main content

llm_manager/backend/
ws_server.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use anyhow::{Context, Result, anyhow};
5use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
6use axum::http::StatusCode;
7use axum::response::IntoResponse;
8use axum::{Router, response::Html, routing::get};
9use futures_util::{SinkExt, StreamExt};
10use tokio::sync::broadcast;
11use tokio::task::JoinHandle;
12use tower_http::trace::TraceLayer;
13use tracing::{error, info, warn};
14
15use crate::models::WsMetrics;
16
17#[derive(Clone)]
18pub struct WsAppState {
19    pub metrics_rx: Arc<broadcast::Receiver<WsMetrics>>,
20    pub auth_key: Option<String>,
21}
22
23pub async fn start_ws_server(
24    port: u16,
25    metrics_rx: Arc<broadcast::Receiver<WsMetrics>>,
26    auth_key: Option<String>,
27    tls_config: Option<axum_server::tls_rustls::RustlsConfig>,
28    host: String,
29) -> Result<JoinHandle<()>> {
30    let state = WsAppState {
31        metrics_rx,
32        auth_key,
33    };
34
35    let app = Router::new()
36        .route("/dashboard", get(serve_dashboard))
37        .route("/ws", get(ws_handler))
38        .route("/health", get(|| async { "OK" }))
39        .layer(TraceLayer::new_for_http())
40        .with_state(state);
41
42    let addr = format!("{host}:{port}");
43
44    match tls_config {
45        Some(tls_cfg) => {
46            let socket_addr: std::net::SocketAddr = addr
47                .parse()
48                .map_err(|e| anyhow!("Invalid bind address {addr} for TLS: {e}"))?;
49            let tls_listener = axum_server::bind_rustls(socket_addr, tls_cfg);
50            let handle = tokio::spawn(async move {
51                if let Err(e) = tls_listener.serve(app.into_make_service()).await {
52                    error!("WebSocket server error: {e}");
53                }
54            });
55            info!("WebSocket server listening on https://{addr}");
56            Ok(handle)
57        }
58        None => {
59            let listener = tokio::net::TcpListener::bind(&addr)
60                .await
61                .with_context(|| format!("Failed to bind WebSocket server to {addr}"))?;
62            let handle = tokio::spawn(async move {
63                if let Err(e) = axum::serve(listener, app).await {
64                    error!("WebSocket server error: {e}");
65                }
66            });
67            info!("WebSocket server listening on http://{addr}");
68            Ok(handle)
69        }
70    }
71}
72
73pub fn stop_ws_server(handle: JoinHandle<()>) {
74    handle.abort();
75}
76
77async fn serve_dashboard(
78    axum::extract::State(state): axum::extract::State<WsAppState>,
79) -> Html<String> {
80    let auth_json = serde_json::to_string(&state.auth_key).unwrap_or("null".to_string());
81    let auth_script = format!("<script>window.__WS_AUTH={};</script>", auth_json);
82    let html = include_str!("../dashboard.html");
83    Html(html.replacen("</body>", &format!("{}\n</body>", auth_script), 1))
84}
85
86async fn ws_handler(
87    ws: WebSocketUpgrade,
88    axum::extract::State(state): axum::extract::State<WsAppState>,
89    axum::extract::Query(query): axum::extract::Query<HashMap<String, String>>,
90) -> impl IntoResponse {
91    if let Some(ref expected) = state.auth_key {
92        if let Some(provided) = query.get("auth").and_then(|v| urlencoding::decode(v).ok()) {
93            if provided != *expected {
94                return StatusCode::UNAUTHORIZED.into_response();
95            }
96        } else {
97            return StatusCode::UNAUTHORIZED.into_response();
98        }
99    }
100    ws.on_upgrade(move |socket| handle_socket(socket, state))
101}
102
103async fn handle_socket(socket: WebSocket, state: WsAppState) {
104    let mut rx = state.metrics_rx.resubscribe();
105    info!("WebSocket client connected");
106
107    let (mut sender, mut receiver) = socket.split();
108
109    loop {
110        tokio::select! {
111            biased;
112            _ = receiver.next() => {
113                info!("WebSocket client disconnected");
114                break;
115            }
116            metrics = rx.recv() => match metrics {
117                Ok(m) => {
118                    let json = match serde_json::to_string(&m) {
119                        Ok(j) => j,
120                        Err(e) => {
121                            error!("Failed to serialize metrics: {e}");
122                            continue;
123                        }
124                    };
125                    if sender.send(Message::Text(json.into())).await.is_err() {
126                        info!("WebSocket client disconnected");
127                        break;
128                    }
129                }
130                Err(broadcast::error::RecvError::Lagged(n)) => {
131                    warn!("WebSocket client lagged behind, skipped {n} metrics");
132                }
133                Err(broadcast::error::RecvError::Closed) => {
134                    break;
135                }
136            },
137        }
138    }
139}