llm-manager 1.1.0

Terminal UI for managing LLMs
Documentation
use std::collections::HashMap;
use std::sync::Arc;

use anyhow::{Context, Result, anyhow};
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::{Router, response::Html, routing::get};
use futures_util::{SinkExt, StreamExt};
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
use tower_http::trace::TraceLayer;
use tracing::{error, info, warn};

use crate::models::WsMetrics;

#[derive(Clone)]
pub struct WsAppState {
    pub metrics_rx: Arc<broadcast::Receiver<WsMetrics>>,
    pub auth_key: Option<String>,
}

pub async fn start_ws_server(
    port: u16,
    metrics_rx: Arc<broadcast::Receiver<WsMetrics>>,
    auth_key: Option<String>,
    tls_config: Option<axum_server::tls_rustls::RustlsConfig>,
    host: String,
) -> Result<JoinHandle<()>> {
    let state = WsAppState {
        metrics_rx,
        auth_key,
    };

    let app = Router::new()
        .route("/dashboard", get(serve_dashboard))
        .route("/ws", get(ws_handler))
        .route("/health", get(|| async { "OK" }))
        .layer(TraceLayer::new_for_http())
        .with_state(state);

    let addr = format!("{host}:{port}");

    match tls_config {
        Some(tls_cfg) => {
            let socket_addr: std::net::SocketAddr = addr
                .parse()
                .map_err(|e| anyhow!("Invalid bind address {addr} for TLS: {e}"))?;
            let tls_listener = axum_server::bind_rustls(socket_addr, tls_cfg);
            let handle = tokio::spawn(async move {
                if let Err(e) = tls_listener.serve(app.into_make_service()).await {
                    error!("WebSocket server error: {e}");
                }
            });
            info!("WebSocket server listening on https://{addr}");
            Ok(handle)
        }
        None => {
            let listener = tokio::net::TcpListener::bind(&addr)
                .await
                .with_context(|| format!("Failed to bind WebSocket server to {addr}"))?;
            let handle = tokio::spawn(async move {
                if let Err(e) = axum::serve(listener, app).await {
                    error!("WebSocket server error: {e}");
                }
            });
            info!("WebSocket server listening on http://{addr}");
            Ok(handle)
        }
    }
}

pub fn stop_ws_server(handle: JoinHandle<()>) {
    handle.abort();
}

async fn serve_dashboard(
    axum::extract::State(state): axum::extract::State<WsAppState>,
) -> Html<String> {
    let auth_json = serde_json::to_string(&state.auth_key).unwrap_or("null".to_string());
    let auth_script = format!("<script>window.__WS_AUTH={};</script>", auth_json);
    let html = include_str!("../dashboard.html");
    Html(html.replacen("</body>", &format!("{}\n</body>", auth_script), 1))
}

async fn ws_handler(
    ws: WebSocketUpgrade,
    axum::extract::State(state): axum::extract::State<WsAppState>,
    axum::extract::Query(query): axum::extract::Query<HashMap<String, String>>,
) -> impl IntoResponse {
    if let Some(ref expected) = state.auth_key {
        if let Some(provided) = query.get("auth").and_then(|v| urlencoding::decode(v).ok()) {
            if provided != *expected {
                return StatusCode::UNAUTHORIZED.into_response();
            }
        } else {
            return StatusCode::UNAUTHORIZED.into_response();
        }
    }
    ws.on_upgrade(move |socket| handle_socket(socket, state))
}

async fn handle_socket(socket: WebSocket, state: WsAppState) {
    let mut rx = state.metrics_rx.resubscribe();
    info!("WebSocket client connected");

    let (mut sender, mut receiver) = socket.split();

    loop {
        tokio::select! {
            biased;
            _ = receiver.next() => {
                info!("WebSocket client disconnected");
                break;
            }
            metrics = rx.recv() => match metrics {
                Ok(m) => {
                    let json = match serde_json::to_string(&m) {
                        Ok(j) => j,
                        Err(e) => {
                            error!("Failed to serialize metrics: {e}");
                            continue;
                        }
                    };
                    if sender.send(Message::Text(json.into())).await.is_err() {
                        info!("WebSocket client disconnected");
                        break;
                    }
                }
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    warn!("WebSocket client lagged behind, skipped {n} metrics");
                }
                Err(broadcast::error::RecvError::Closed) => {
                    break;
                }
            },
        }
    }
}