1use crate::state::AppState;
2use axum::{
3 extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State},
4 response::IntoResponse,
5};
6use futures_util::{SinkExt, StreamExt};
7use std::sync::Arc;
8use tokio::sync::broadcast;
9use tracing::warn;
10
11pub async fn ws_handler(
13 ws: WebSocketUpgrade,
14 State(state): State<Arc<AppState>>,
15) -> impl IntoResponse {
16 ws.on_upgrade(|socket| handle_ws(socket, state))
17}
18
19async fn handle_ws(socket: WebSocket, state: Arc<AppState>) {
20 let mut rx = state.tx.subscribe();
21 let (mut ws_sender, mut ws_receiver) = socket.split();
22
23 let connected = serde_json::json!({"type": "connected"});
25 if ws_sender
26 .send(Message::Text(connected.to_string()))
27 .await
28 .is_err()
29 {
30 return;
31 }
32
33 let state_clone = state.clone();
35 let reader = tokio::spawn(async move {
36 while let Some(Ok(msg)) = ws_receiver.next().await {
37 match msg {
38 Message::Text(text) => {
39 if let Ok(cmd) = serde_json::from_str::<serde_json::Value>(&text) {
41 let cmd_type = cmd.get("type").and_then(|v| v.as_str()).unwrap_or("");
42 match cmd_type {
43 "agent-status" => {
44 let payload = serde_json::json!({
46 "type": "agent-status",
47 "agent": cmd.get("agent"),
48 "status": cmd.get("status"),
49 "task": cmd.get("task"),
50 });
51 let _ = state_clone.tx.send(payload.to_string());
52 }
53 "ping" => {
54 }
56 _ => {
57 warn!("Unknown WS command: {cmd_type}");
58 }
59 }
60 }
61 }
62 Message::Close(_) => break,
63 _ => {}
64 }
65 }
66 });
67
68 let writer = tokio::spawn(async move {
70 loop {
71 match rx.recv().await {
72 Ok(data) => {
73 if ws_sender
74 .send(Message::Text(data))
75 .await
76 .is_err()
77 {
78 break;
79 }
80 }
81 Err(broadcast::error::RecvError::Lagged(n)) => {
82 warn!("WS client lagged by {n} messages");
83 }
84 Err(_) => break,
85 }
86 }
87 });
88
89 tokio::select! {
91 _ = reader => {},
92 _ = writer => {},
93 }
94}