use axum::{
Json,
extract::{
FromRequestParts, State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
http::StatusCode,
response::{IntoResponse, Response},
};
use tokio::sync::broadcast;
use crate::{
models::{ApiResponse, OutputMode},
state::AppState,
};
/// HTTP entry point that upgrades a request to a CSI streaming WebSocket.
///
/// The server's current output mode is checked first. When it is [`OutputMode::Dump`], the
/// request is refused with `403 Forbidden` and a JSON [`ApiResponse`], without attempting the
/// upgrade. Otherwise a [`WebSocketUpgrade`] is extracted from the request parts. If extraction
/// fails (e.g. the request is not a WebSocket handshake), the extractor's rejection is returned.
/// On success, a receiver is subscribed to `state.csi_tx` and handed to [`handle_socket`] once
/// the upgrade completes.
pub async fn ws_handler(State(state): State<AppState>, req: axum::extract::Request) -> Response {
    // Evaluate into a plain bool so the watch borrow guard is released immediately.
    let dump_only = *state.output_mode_tx.borrow() == OutputMode::Dump;
    if dump_only {
        let body = ApiResponse {
            success: false,
            message: "Server is in dump-only mode; WebSocket streaming is disabled".to_string(),
        };
        return (StatusCode::FORBIDDEN, Json(body)).into_response();
    }

    // The mode check must come before the upgrade, so the extractor is run by hand
    // on the request parts instead of appearing in the handler signature.
    let (mut parts, _body) = req.into_parts();
    match WebSocketUpgrade::from_request_parts(&mut parts, &state).await {
        Ok(upgrade) => {
            // Subscribe only once the handshake is known to be valid.
            let receiver = state.csi_tx.subscribe();
            upgrade
                .on_upgrade(move |socket| handle_socket(socket, receiver))
                .into_response()
        }
        Err(rejection) => rejection.into_response(),
    }
}
/// Forwards broadcast CSI packets to a single WebSocket client until the session ends.
///
/// Each `Vec<u8>` received from `rx` is sent to the client as one binary frame. The loop exits
/// when any of the following happens:
/// - sending a frame to the client fails;
/// - the broadcast channel reports `Closed`;
/// - the client sends a Close frame, or its message stream ends (`None`);
/// - reading from the client returns an error.
///
/// If this receiver reports `Lagged(n)`, the number of skipped packets is logged as a
/// warning and streaming continues. Non-Close frames sent by the client are ignored.
async fn handle_socket(mut socket: WebSocket, mut rx: broadcast::Receiver<Vec<u8>>) {
    loop {
        tokio::select! {
            result = rx.recv() => {
                match result {
                    Ok(data) => {
                        if socket.send(Message::Binary(data.into())).await.is_err() {
                            break;
                        }
                    }
                    Err(broadcast::error::RecvError::Closed) => break,
                    Err(broadcast::error::RecvError::Lagged(n)) => {
                        tracing::warn!("WebSocket client lagged — dropped {n} CSI packets");
                    }
                }
            }
            msg = socket.recv() => {
                match msg {
                    Some(Ok(Message::Close(_))) | None => break,
                    // A read error means the connection can no longer be trusted. Continuing
                    // would keep polling a broken stream (and keep the broadcast subscription
                    // alive), so end the session instead of silently ignoring it.
                    Some(Err(err)) => {
                        tracing::debug!("WebSocket receive error: {err}");
                        break;
                    }
                    // Text/Binary/Ping/Pong frames from the client carry nothing this
                    // stream consumes.
                    Some(Ok(_)) => {}
                }
            }
        }
    }
    tracing::debug!("WebSocket client disconnected");
}