intent_engine/dashboard/
websocket.rs

1// WebSocket support for Dashboard
2// Handles real-time communication between MCP servers and UI clients
3
4use axum::{
5    extract::{
6        ws::{Message, WebSocket},
7        State, WebSocketUpgrade,
8    },
9    response::IntoResponse,
10};
11use futures_util::{SinkExt, StreamExt};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::RwLock;
16
17/// Project information sent by MCP servers
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ProjectInfo {
20    pub path: String,
21    pub name: String,
22    pub db_path: String,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub agent: Option<String>,
25}
26
27/// MCP connection entry
28#[derive(Debug)]
29pub struct McpConnection {
30    pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
31    pub project: ProjectInfo,
32    pub connected_at: chrono::DateTime<chrono::Utc>,
33}
34
35/// UI connection entry
36#[derive(Debug)]
37pub struct UiConnection {
38    pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
39    pub connected_at: chrono::DateTime<chrono::Utc>,
40}
41
42/// Shared WebSocket state
43#[derive(Clone)]
44pub struct WebSocketState {
45    /// Project path → MCP connection
46    pub mcp_connections: Arc<RwLock<HashMap<String, McpConnection>>>,
47    /// List of active UI connections
48    pub ui_connections: Arc<RwLock<Vec<UiConnection>>>,
49}
50
51impl Default for WebSocketState {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl WebSocketState {
58    pub fn new() -> Self {
59        Self {
60            mcp_connections: Arc::new(RwLock::new(HashMap::new())),
61            ui_connections: Arc::new(RwLock::new(Vec::new())),
62        }
63    }
64
65    /// Broadcast message to all UI connections
66    pub async fn broadcast_to_ui(&self, message: &str) {
67        let connections = self.ui_connections.read().await;
68        for conn in connections.iter() {
69            let _ = conn.tx.send(Message::Text(message.to_string()));
70        }
71    }
72
73    /// Get list of currently connected projects from Registry
74    pub async fn get_online_projects(&self) -> Vec<ProjectInfo> {
75        // Load from Registry to get accurate mcp_connected status
76        // This ensures UI gets complete project list even if WebSocket connections haven't been established yet
77        match crate::dashboard::registry::ProjectRegistry::load() {
78            Ok(registry) => registry
79                .projects
80                .iter()
81                .filter(|p| p.mcp_connected)
82                .map(|p| ProjectInfo {
83                    name: p.name.clone(),
84                    path: p.path.display().to_string(),
85                    db_path: p.db_path.display().to_string(),
86                    agent: p.mcp_agent.clone(),
87                })
88                .collect(),
89            Err(e) => {
90                tracing::warn!("Failed to load registry for online projects: {}", e);
91                Vec::new()
92            },
93        }
94    }
95}
96
97/// Message types from MCP to Dashboard
98#[derive(Debug, Deserialize)]
99#[serde(tag = "type")]
100enum McpMessage {
101    #[serde(rename = "register")]
102    Register { project: ProjectInfo },
103    #[serde(rename = "ping")]
104    Ping,
105}
106
107/// Message types from Dashboard to MCP
108#[derive(Debug, Serialize)]
109#[serde(tag = "type")]
110enum McpResponse {
111    #[serde(rename = "registered")]
112    Registered { success: bool },
113    #[serde(rename = "pong")]
114    Pong,
115}
116
117/// Message types from Dashboard to UI
118#[derive(Debug, Serialize)]
119#[serde(tag = "type")]
120enum UiMessage {
121    #[serde(rename = "init")]
122    Init { projects: Vec<ProjectInfo> },
123    #[serde(rename = "project_online")]
124    ProjectOnline { project: ProjectInfo },
125    #[serde(rename = "project_offline")]
126    ProjectOffline { project_path: String },
127    #[serde(rename = "ping")]
128    Ping,
129}
130
131/// Handle MCP WebSocket connections
132pub async fn handle_mcp_websocket(
133    ws: WebSocketUpgrade,
134    State(state): State<WebSocketState>,
135) -> impl IntoResponse {
136    ws.on_upgrade(move |socket| handle_mcp_socket(socket, state))
137}
138
139async fn handle_mcp_socket(socket: WebSocket, state: WebSocketState) {
140    let (mut sender, mut receiver) = socket.split();
141    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
142
143    // Spawn task to forward messages from channel to WebSocket
144    let mut send_task = tokio::spawn(async move {
145        while let Some(msg) = rx.recv().await {
146            if sender.send(msg).await.is_err() {
147                break;
148            }
149        }
150    });
151
152    // Variables to track this connection
153    let mut project_path: Option<String> = None;
154
155    // Clone state for use inside recv_task
156    let state_for_recv = state.clone();
157
158    // Clone tx for heartbeat task
159    let heartbeat_tx = tx.clone();
160
161    // Spawn heartbeat task - send ping every 30 seconds
162    let mut heartbeat_task = tokio::spawn(async move {
163        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
164        loop {
165            interval.tick().await;
166            let ping_msg = McpResponse::Pong; // Use Pong as keepalive for MCP
167            if heartbeat_tx
168                .send(Message::Text(serde_json::to_string(&ping_msg).unwrap()))
169                .is_err()
170            {
171                // Connection closed
172                break;
173            }
174            tracing::trace!("Sent heartbeat to MCP client");
175        }
176    });
177
178    // Handle incoming messages
179    let mut recv_task = tokio::spawn(async move {
180        while let Some(Ok(msg)) = receiver.next().await {
181            match msg {
182                Message::Text(text) => {
183                    // Parse incoming message
184                    match serde_json::from_str::<McpMessage>(&text) {
185                        Ok(McpMessage::Register { project }) => {
186                            tracing::info!("MCP registering project: {}", project.name);
187
188                            // Store connection
189                            let path = project.path.clone();
190                            let conn = McpConnection {
191                                tx: tx.clone(),
192                                project: project.clone(),
193                                connected_at: chrono::Utc::now(),
194                            };
195
196                            state_for_recv
197                                .mcp_connections
198                                .write()
199                                .await
200                                .insert(path.clone(), conn);
201                            project_path = Some(path.clone());
202
203                            // Update Registry immediately to set mcp_connected=true
204                            let project_path_buf = std::path::PathBuf::from(&path);
205                            match crate::dashboard::registry::ProjectRegistry::load() {
206                                Ok(mut registry) => {
207                                    if let Err(e) = registry.register_mcp_connection(
208                                        &project_path_buf,
209                                        Some("mcp-client".to_string()),
210                                    ) {
211                                        tracing::warn!(
212                                            "Failed to update Registry for MCP connection: {}",
213                                            e
214                                        );
215                                    } else {
216                                        tracing::info!(
217                                            "✓ Updated Registry: {} is now mcp_connected=true",
218                                            project.name
219                                        );
220                                    }
221                                },
222                                Err(e) => {
223                                    tracing::warn!("Failed to load Registry: {}", e);
224                                },
225                            }
226
227                            // Send confirmation
228                            let response = McpResponse::Registered { success: true };
229                            let _ =
230                                tx.send(Message::Text(serde_json::to_string(&response).unwrap()));
231
232                            // Broadcast to UI clients
233                            let ui_msg = UiMessage::ProjectOnline { project };
234                            state_for_recv
235                                .broadcast_to_ui(&serde_json::to_string(&ui_msg).unwrap())
236                                .await;
237                        },
238                        Ok(McpMessage::Ping) => {
239                            // Respond with pong
240                            let response = McpResponse::Pong;
241                            let _ =
242                                tx.send(Message::Text(serde_json::to_string(&response).unwrap()));
243                        },
244                        Err(e) => {
245                            tracing::warn!("Failed to parse MCP message: {}", e);
246                        },
247                    }
248                },
249                Message::Close(_) => {
250                    break;
251                },
252                _ => {},
253            }
254        }
255
256        project_path
257    });
258
259    // Wait for any task to finish
260    tokio::select! {
261        _ = (&mut send_task) => {
262            recv_task.abort();
263            heartbeat_task.abort();
264        }
265        project_path_result = (&mut recv_task) => {
266            send_task.abort();
267            heartbeat_task.abort();
268            if let Ok(Some(path)) = project_path_result {
269                // Clean up connection
270                state.mcp_connections.write().await.remove(&path);
271
272                // Update Registry immediately to set mcp_connected=false
273                let project_path_buf = std::path::PathBuf::from(&path);
274                match crate::dashboard::registry::ProjectRegistry::load() {
275                    Ok(mut registry) => {
276                        if let Err(e) = registry.unregister_mcp_connection(&project_path_buf) {
277                            tracing::warn!("Failed to update Registry for MCP disconnection: {}", e);
278                        } else {
279                            tracing::info!("✓ Updated Registry: {} is now mcp_connected=false", path);
280                        }
281                    }
282                    Err(e) => {
283                        tracing::warn!("Failed to load Registry: {}", e);
284                    }
285                }
286
287                // Notify UI clients
288                let ui_msg = UiMessage::ProjectOffline { project_path: path.clone() };
289                state
290                    .broadcast_to_ui(&serde_json::to_string(&ui_msg).unwrap())
291                    .await;
292
293                tracing::info!("MCP disconnected: {}", path);
294            }
295        }
296        _ = (&mut heartbeat_task) => {
297            send_task.abort();
298            recv_task.abort();
299        }
300    }
301}
302
303/// Handle UI WebSocket connections
304pub async fn handle_ui_websocket(
305    ws: WebSocketUpgrade,
306    State(state): State<WebSocketState>,
307) -> impl IntoResponse {
308    ws.on_upgrade(move |socket| handle_ui_socket(socket, state))
309}
310
311async fn handle_ui_socket(socket: WebSocket, state: WebSocketState) {
312    let (mut sender, mut receiver) = socket.split();
313    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
314
315    // Spawn task to forward messages from channel to WebSocket
316    let mut send_task = tokio::spawn(async move {
317        while let Some(msg) = rx.recv().await {
318            if sender.send(msg).await.is_err() {
319                break;
320            }
321        }
322    });
323
324    // Send initial project list
325    let projects = state.get_online_projects().await;
326    let init_msg = UiMessage::Init { projects };
327    let _ = tx.send(Message::Text(serde_json::to_string(&init_msg).unwrap()));
328
329    // Register this UI connection
330    let conn = UiConnection {
331        tx: tx.clone(),
332        connected_at: chrono::Utc::now(),
333    };
334    let conn_index = {
335        let mut connections = state.ui_connections.write().await;
336        connections.push(conn);
337        connections.len() - 1
338    };
339
340    tracing::info!("UI client connected");
341
342    // Clone tx for heartbeat task
343    let heartbeat_tx = tx.clone();
344
345    // Spawn heartbeat task - send ping every 30 seconds
346    let mut heartbeat_task = tokio::spawn(async move {
347        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
348        loop {
349            interval.tick().await;
350            let ping_msg = UiMessage::Ping;
351            if heartbeat_tx
352                .send(Message::Text(serde_json::to_string(&ping_msg).unwrap()))
353                .is_err()
354            {
355                // Connection closed
356                break;
357            }
358            tracing::trace!("Sent heartbeat ping to UI client");
359        }
360    });
361
362    // Handle incoming messages (mostly just keep-alive and pong)
363    let mut recv_task = tokio::spawn(async move {
364        while let Some(Ok(msg)) = receiver.next().await {
365            match msg {
366                Message::Text(text) => {
367                    // UI can send pong or other messages
368                    tracing::trace!("Received from UI: {}", text);
369                },
370                Message::Pong(_) => {
371                    tracing::trace!("Received pong from UI");
372                },
373                Message::Close(_) => {
374                    break;
375                },
376                _ => {},
377            }
378        }
379    });
380
381    // Wait for any task to finish
382    tokio::select! {
383        _ = (&mut send_task) => {
384            recv_task.abort();
385            heartbeat_task.abort();
386        }
387        _ = (&mut recv_task) => {
388            send_task.abort();
389            heartbeat_task.abort();
390        }
391        _ = (&mut heartbeat_task) => {
392            send_task.abort();
393            recv_task.abort();
394        }
395    }
396
397    // Clean up UI connection
398    state.ui_connections.write().await.swap_remove(conn_index);
399    tracing::info!("UI client disconnected");
400}