Skip to main content

flow_server/
ws.rs

1use crate::state::AppState;
2use axum::{
3    extract::{
4        ws::{Message, WebSocket, WebSocketUpgrade},
5        State,
6    },
7    response::IntoResponse,
8};
9use futures_util::{SinkExt, StreamExt};
10use std::sync::Arc;
11use tokio::sync::broadcast;
12use tracing::warn;
13
14/// GET /api/ws — WebSocket endpoint for bidirectional communication
15pub async fn ws_handler(
16    ws: WebSocketUpgrade,
17    State(state): State<Arc<AppState>>,
18) -> impl IntoResponse {
19    ws.on_upgrade(|socket| handle_ws(socket, state))
20}
21
22async fn handle_ws(socket: WebSocket, state: Arc<AppState>) {
23    let mut rx = state.tx.subscribe();
24    let (mut ws_sender, mut ws_receiver) = socket.split();
25
26    // Send initial connected message
27    let connected = serde_json::json!({"type": "connected"});
28    if ws_sender
29        .send(Message::Text(connected.to_string()))
30        .await
31        .is_err()
32    {
33        return;
34    }
35
36    // Reader task: handle incoming messages from the client
37    let state_clone = state.clone();
38    let reader = tokio::spawn(async move {
39        while let Some(Ok(msg)) = ws_receiver.next().await {
40            match msg {
41                Message::Text(text) => {
42                    // Parse incoming commands from agents
43                    if let Ok(cmd) = serde_json::from_str::<serde_json::Value>(&text) {
44                        let cmd_type = cmd.get("type").and_then(|v| v.as_str()).unwrap_or("");
45                        match cmd_type {
46                            "agent-status" => {
47                                // Broadcast agent status update to all clients
48                                let payload = serde_json::json!({
49                                    "type": "agent-status",
50                                    "agent": cmd.get("agent"),
51                                    "status": cmd.get("status"),
52                                    "task": cmd.get("task"),
53                                });
54                                let _ = state_clone.tx.send(payload.to_string());
55                            }
56                            "ping" => {
57                                // Client keepalive, no action needed
58                            }
59                            _ => {
60                                warn!("Unknown WS command: {cmd_type}");
61                            }
62                        }
63                    }
64                }
65                Message::Close(_) => break,
66                _ => {}
67            }
68        }
69    });
70
71    // Writer task: forward broadcast events to the WebSocket client
72    let writer = tokio::spawn(async move {
73        loop {
74            match rx.recv().await {
75                Ok(data) => {
76                    if ws_sender.send(Message::Text(data)).await.is_err() {
77                        break;
78                    }
79                }
80                Err(broadcast::error::RecvError::Lagged(n)) => {
81                    warn!("WS client lagged by {n} messages");
82                }
83                Err(_) => break,
84            }
85        }
86    });
87
88    // Wait for either task to finish
89    tokio::select! {
90        _ = reader => {},
91        _ = writer => {},
92    }
93}