1use crate::state::AppState;
2use axum::{
3 extract::{
4 ws::{Message, WebSocket, WebSocketUpgrade},
5 State,
6 },
7 response::IntoResponse,
8};
9use futures_util::{SinkExt, StreamExt};
10use std::sync::Arc;
11use tokio::sync::broadcast;
12use tracing::warn;
13
14pub async fn ws_handler(
16 ws: WebSocketUpgrade,
17 State(state): State<Arc<AppState>>,
18) -> impl IntoResponse {
19 ws.on_upgrade(|socket| handle_ws(socket, state))
20}
21
22async fn handle_ws(socket: WebSocket, state: Arc<AppState>) {
23 let mut rx = state.tx.subscribe();
24 let (mut ws_sender, mut ws_receiver) = socket.split();
25
26 let connected = serde_json::json!({"type": "connected"});
28 if ws_sender
29 .send(Message::Text(connected.to_string()))
30 .await
31 .is_err()
32 {
33 return;
34 }
35
36 let state_clone = state.clone();
38 let reader = tokio::spawn(async move {
39 while let Some(Ok(msg)) = ws_receiver.next().await {
40 match msg {
41 Message::Text(text) => {
42 if let Ok(cmd) = serde_json::from_str::<serde_json::Value>(&text) {
44 let cmd_type = cmd.get("type").and_then(|v| v.as_str()).unwrap_or("");
45 match cmd_type {
46 "agent-status" => {
47 let payload = serde_json::json!({
49 "type": "agent-status",
50 "agent": cmd.get("agent"),
51 "status": cmd.get("status"),
52 "task": cmd.get("task"),
53 });
54 let _ = state_clone.tx.send(payload.to_string());
55 }
56 "ping" => {
57 }
59 _ => {
60 warn!("Unknown WS command: {cmd_type}");
61 }
62 }
63 }
64 }
65 Message::Close(_) => break,
66 _ => {}
67 }
68 }
69 });
70
71 let writer = tokio::spawn(async move {
73 loop {
74 match rx.recv().await {
75 Ok(data) => {
76 if ws_sender.send(Message::Text(data)).await.is_err() {
77 break;
78 }
79 }
80 Err(broadcast::error::RecvError::Lagged(n)) => {
81 warn!("WS client lagged by {n} messages");
82 }
83 Err(_) => break,
84 }
85 }
86 });
87
88 tokio::select! {
90 _ = reader => {},
91 _ = writer => {},
92 }
93}