llm_manager/backend/
ws_server.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use anyhow::{Context, Result, anyhow};
5use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
6use axum::http::StatusCode;
7use axum::response::IntoResponse;
8use axum::{Router, response::Html, routing::get};
9use futures_util::{SinkExt, StreamExt};
10use tokio::sync::broadcast;
11use tokio::task::JoinHandle;
12use tower_http::trace::TraceLayer;
13use tracing::{error, info, warn};
14
15use crate::models::WsMetrics;
16
17#[derive(Clone)]
18pub struct WsAppState {
19 pub metrics_rx: Arc<broadcast::Receiver<WsMetrics>>,
20 pub auth_key: Option<String>,
21}
22
23pub async fn start_ws_server(
24 port: u16,
25 metrics_rx: Arc<broadcast::Receiver<WsMetrics>>,
26 auth_key: Option<String>,
27 tls_config: Option<axum_server::tls_rustls::RustlsConfig>,
28 host: String,
29) -> Result<JoinHandle<()>> {
30 let state = WsAppState {
31 metrics_rx,
32 auth_key,
33 };
34
35 let app = Router::new()
36 .route("/dashboard", get(serve_dashboard))
37 .route("/ws", get(ws_handler))
38 .route("/health", get(|| async { "OK" }))
39 .layer(TraceLayer::new_for_http())
40 .with_state(state);
41
42 let addr = format!("{host}:{port}");
43
44 match tls_config {
45 Some(tls_cfg) => {
46 let socket_addr: std::net::SocketAddr = addr
47 .parse()
48 .map_err(|e| anyhow!("Invalid bind address {addr} for TLS: {e}"))?;
49 let tls_listener = axum_server::bind_rustls(socket_addr, tls_cfg);
50 let handle = tokio::spawn(async move {
51 if let Err(e) = tls_listener.serve(app.into_make_service()).await {
52 error!("WebSocket server error: {e}");
53 }
54 });
55 info!("WebSocket server listening on https://{addr}");
56 Ok(handle)
57 }
58 None => {
59 let listener = tokio::net::TcpListener::bind(&addr)
60 .await
61 .with_context(|| format!("Failed to bind WebSocket server to {addr}"))?;
62 let handle = tokio::spawn(async move {
63 if let Err(e) = axum::serve(listener, app).await {
64 error!("WebSocket server error: {e}");
65 }
66 });
67 info!("WebSocket server listening on http://{addr}");
68 Ok(handle)
69 }
70 }
71}
72
73pub fn stop_ws_server(handle: JoinHandle<()>) {
74 handle.abort();
75}
76
77async fn serve_dashboard(
78 axum::extract::State(state): axum::extract::State<WsAppState>,
79) -> Html<String> {
80 let auth_json = serde_json::to_string(&state.auth_key).unwrap_or("null".to_string());
81 let auth_script = format!("<script>window.__WS_AUTH={};</script>", auth_json);
82 let html = include_str!("../dashboard.html");
83 Html(html.replacen("</body>", &format!("{}\n</body>", auth_script), 1))
84}
85
86async fn ws_handler(
87 ws: WebSocketUpgrade,
88 axum::extract::State(state): axum::extract::State<WsAppState>,
89 axum::extract::Query(query): axum::extract::Query<HashMap<String, String>>,
90) -> impl IntoResponse {
91 if let Some(ref expected) = state.auth_key {
92 if let Some(provided) = query.get("auth").and_then(|v| urlencoding::decode(v).ok()) {
93 if provided != *expected {
94 return StatusCode::UNAUTHORIZED.into_response();
95 }
96 } else {
97 return StatusCode::UNAUTHORIZED.into_response();
98 }
99 }
100 ws.on_upgrade(move |socket| handle_socket(socket, state))
101}
102
103async fn handle_socket(socket: WebSocket, state: WsAppState) {
104 let mut rx = state.metrics_rx.resubscribe();
105 info!("WebSocket client connected");
106
107 let (mut sender, mut receiver) = socket.split();
108
109 loop {
110 tokio::select! {
111 biased;
112 _ = receiver.next() => {
113 info!("WebSocket client disconnected");
114 break;
115 }
116 metrics = rx.recv() => match metrics {
117 Ok(m) => {
118 let json = match serde_json::to_string(&m) {
119 Ok(j) => j,
120 Err(e) => {
121 error!("Failed to serialize metrics: {e}");
122 continue;
123 }
124 };
125 if sender.send(Message::Text(json.into())).await.is_err() {
126 info!("WebSocket client disconnected");
127 break;
128 }
129 }
130 Err(broadcast::error::RecvError::Lagged(n)) => {
131 warn!("WebSocket client lagged behind, skipped {n} metrics");
132 }
133 Err(broadcast::error::RecvError::Closed) => {
134 break;
135 }
136 },
137 }
138 }
139}