batuta/serve/banco/
handlers_ws.rs1use axum::{
6 extract::{
7 ws::{Message, WebSocket},
8 State, WebSocketUpgrade,
9 },
10 response::Response,
11};
12
13use super::state::BancoState;
14
15pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<BancoState>) -> Response {
17 ws.on_upgrade(move |socket| handle_socket(socket, state))
18}
19
20async fn handle_socket(mut socket: WebSocket, state: BancoState) {
22 let mut rx = state.events.subscribe();
23
24 let welcome = serde_json::json!({
26 "type": "connected",
27 "data": {
28 "endpoints": 66,
29 "model_loaded": state.model.is_loaded(),
30 }
31 });
32 if socket.send(Message::Text(welcome.to_string())).await.is_err() {
33 return;
34 }
35
36 loop {
38 tokio::select! {
39 event = rx.recv() => {
41 match event {
42 Ok(json) => {
43 if socket.send(Message::Text(json)).await.is_err() {
44 break; }
46 }
47 Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
48 let lag_msg = serde_json::json!({
50 "type": "system_event",
51 "data": {"message": format!("Missed {n} events (slow consumer)")}
52 });
53 let _ = socket.send(Message::Text(lag_msg.to_string())).await;
54 }
55 Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
56 }
57 }
58 msg = socket.recv() => {
60 match msg {
61 Some(Ok(Message::Close(_))) | None => break,
62 Some(Ok(Message::Ping(data))) => {
63 let _ = socket.send(Message::Pong(data)).await;
64 }
65 _ => {} }
67 }
68 }
69 }
70}