Skip to main content

flow_server/
ws.rs

1use crate::state::AppState;
2use axum::{
3    extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State},
4    response::IntoResponse,
5};
6use futures_util::{SinkExt, StreamExt};
7use std::sync::Arc;
8use tokio::sync::broadcast;
9use tracing::warn;
10
11/// GET /api/ws — WebSocket endpoint for bidirectional communication
12pub async fn ws_handler(
13    ws: WebSocketUpgrade,
14    State(state): State<Arc<AppState>>,
15) -> impl IntoResponse {
16    ws.on_upgrade(|socket| handle_ws(socket, state))
17}
18
19async fn handle_ws(socket: WebSocket, state: Arc<AppState>) {
20    let mut rx = state.tx.subscribe();
21    let (mut ws_sender, mut ws_receiver) = socket.split();
22
23    // Send initial connected message
24    let connected = serde_json::json!({"type": "connected"});
25    if ws_sender
26        .send(Message::Text(connected.to_string()))
27        .await
28        .is_err()
29    {
30        return;
31    }
32
33    // Reader task: handle incoming messages from the client
34    let state_clone = state.clone();
35    let reader = tokio::spawn(async move {
36        while let Some(Ok(msg)) = ws_receiver.next().await {
37            match msg {
38                Message::Text(text) => {
39                    // Parse incoming commands from agents
40                    if let Ok(cmd) = serde_json::from_str::<serde_json::Value>(&text) {
41                        let cmd_type = cmd.get("type").and_then(|v| v.as_str()).unwrap_or("");
42                        match cmd_type {
43                            "agent-status" => {
44                                // Broadcast agent status update to all clients
45                                let payload = serde_json::json!({
46                                    "type": "agent-status",
47                                    "agent": cmd.get("agent"),
48                                    "status": cmd.get("status"),
49                                    "task": cmd.get("task"),
50                                });
51                                let _ = state_clone.tx.send(payload.to_string());
52                            }
53                            "ping" => {
54                                // Client keepalive, no action needed
55                            }
56                            _ => {
57                                warn!("Unknown WS command: {cmd_type}");
58                            }
59                        }
60                    }
61                }
62                Message::Close(_) => break,
63                _ => {}
64            }
65        }
66    });
67
68    // Writer task: forward broadcast events to the WebSocket client
69    let writer = tokio::spawn(async move {
70        loop {
71            match rx.recv().await {
72                Ok(data) => {
73                    if ws_sender
74                        .send(Message::Text(data))
75                        .await
76                        .is_err()
77                    {
78                        break;
79                    }
80                }
81                Err(broadcast::error::RecvError::Lagged(n)) => {
82                    warn!("WS client lagged by {n} messages");
83                }
84                Err(_) => break,
85            }
86        }
87    });
88
89    // Wait for either task to finish
90    tokio::select! {
91        _ = reader => {},
92        _ = writer => {},
93    }
94}