intent_engine/dashboard/
websocket.rs

// WebSocket support for the Dashboard.
// Handles real-time communication between MCP servers and UI clients.
3
4use axum::{
5    extract::{
6        ws::{Message, WebSocket},
7        State, WebSocketUpgrade,
8    },
9    response::IntoResponse,
10};
11use futures_util::{SinkExt, StreamExt};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::RwLock;
16
17/// Project information sent by MCP servers
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ProjectInfo {
20    pub path: String,
21    pub name: String,
22    pub db_path: String,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub agent: Option<String>,
25}
26
27/// MCP connection entry
28#[derive(Debug)]
29pub struct McpConnection {
30    pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
31    pub project: ProjectInfo,
32    pub connected_at: chrono::DateTime<chrono::Utc>,
33}
34
35/// UI connection entry
36#[derive(Debug)]
37pub struct UiConnection {
38    pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
39    pub connected_at: chrono::DateTime<chrono::Utc>,
40}
41
42/// Shared WebSocket state
43#[derive(Clone)]
44pub struct WebSocketState {
45    /// Project path → MCP connection
46    pub mcp_connections: Arc<RwLock<HashMap<String, McpConnection>>>,
47    /// List of active UI connections
48    pub ui_connections: Arc<RwLock<Vec<UiConnection>>>,
49}
50
51impl Default for WebSocketState {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl WebSocketState {
58    pub fn new() -> Self {
59        Self {
60            mcp_connections: Arc::new(RwLock::new(HashMap::new())),
61            ui_connections: Arc::new(RwLock::new(Vec::new())),
62        }
63    }
64
65    /// Broadcast message to all UI connections
66    pub async fn broadcast_to_ui(&self, message: &str) {
67        let connections = self.ui_connections.read().await;
68        for conn in connections.iter() {
69            let _ = conn.tx.send(Message::Text(message.to_string()));
70        }
71    }
72
73    /// Get list of currently connected projects from Registry
74    pub async fn get_online_projects(&self) -> Vec<ProjectInfo> {
75        // Load from Registry to get accurate mcp_connected status
76        // This ensures UI gets complete project list even if WebSocket connections haven't been established yet
77        match crate::dashboard::registry::ProjectRegistry::load() {
78            Ok(registry) => registry
79                .projects
80                .iter()
81                .filter(|p| p.mcp_connected)
82                .map(|p| ProjectInfo {
83                    name: p.name.clone(),
84                    path: p.path.display().to_string(),
85                    db_path: p.db_path.display().to_string(),
86                    agent: p.mcp_agent.clone(),
87                })
88                .collect(),
89            Err(e) => {
90                tracing::warn!("Failed to load registry for online projects: {}", e);
91                Vec::new()
92            },
93        }
94    }
95}
96
97/// Message types from MCP to Dashboard
98#[derive(Debug, Deserialize)]
99#[serde(tag = "type")]
100enum McpMessage {
101    #[serde(rename = "register")]
102    Register { project: ProjectInfo },
103    #[serde(rename = "ping")]
104    Ping,
105}
106
107/// Message types from Dashboard to MCP
108#[derive(Debug, Serialize)]
109#[serde(tag = "type")]
110enum McpResponse {
111    #[serde(rename = "registered")]
112    Registered { success: bool },
113    #[serde(rename = "pong")]
114    Pong,
115}
116
117/// Message types from Dashboard to UI
118#[derive(Debug, Serialize)]
119#[serde(tag = "type")]
120enum UiMessage {
121    #[serde(rename = "init")]
122    Init { projects: Vec<ProjectInfo> },
123    #[serde(rename = "project_online")]
124    ProjectOnline { project: ProjectInfo },
125    #[serde(rename = "project_offline")]
126    ProjectOffline { project_path: String },
127    #[serde(rename = "ping")]
128    Ping,
129}
130
131/// Handle MCP WebSocket connections
132pub async fn handle_mcp_websocket(
133    ws: WebSocketUpgrade,
134    State(state): State<WebSocketState>,
135) -> impl IntoResponse {
136    ws.on_upgrade(move |socket| handle_mcp_socket(socket, state))
137}
138
139async fn handle_mcp_socket(socket: WebSocket, state: WebSocketState) {
140    let (mut sender, mut receiver) = socket.split();
141    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
142
143    // Spawn task to forward messages from channel to WebSocket
144    let mut send_task = tokio::spawn(async move {
145        while let Some(msg) = rx.recv().await {
146            if sender.send(msg).await.is_err() {
147                break;
148            }
149        }
150    });
151
152    // Variables to track this connection
153    let mut project_path: Option<String> = None;
154
155    // Clone state for use inside recv_task
156    let state_for_recv = state.clone();
157
158    // Clone tx for heartbeat task
159    let heartbeat_tx = tx.clone();
160
161    // Spawn heartbeat task - send ping every 30 seconds
162    let mut heartbeat_task = tokio::spawn(async move {
163        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
164        loop {
165            interval.tick().await;
166            let ping_msg = McpResponse::Pong; // Use Pong as keepalive for MCP
167            if heartbeat_tx
168                .send(Message::Text(serde_json::to_string(&ping_msg).unwrap()))
169                .is_err()
170            {
171                // Connection closed
172                break;
173            }
174            tracing::trace!("Sent heartbeat to MCP client");
175        }
176    });
177
178    // Handle incoming messages
179    let mut recv_task = tokio::spawn(async move {
180        while let Some(Ok(msg)) = receiver.next().await {
181            match msg {
182                Message::Text(text) => {
183                    // Parse incoming message
184                    match serde_json::from_str::<McpMessage>(&text) {
185                        Ok(McpMessage::Register { project }) => {
186                            tracing::info!("MCP registering project: {}", project.name);
187
188                            let path = project.path.clone();
189                            let project_path_buf = std::path::PathBuf::from(&path);
190
191                            // Validate project path - reject temporary directories (Defense Layer 5)
192                            // This prevents test environments from polluting the Dashboard registry
193                            let normalized_path = project_path_buf
194                                .canonicalize()
195                                .unwrap_or_else(|_| project_path_buf.clone());
196
197                            let temp_dir = std::env::temp_dir();
198                            let is_temp_path = normalized_path.starts_with(&temp_dir);
199
200                            if is_temp_path {
201                                tracing::warn!(
202                                    "Rejecting MCP registration for temporary/invalid path: {}",
203                                    path
204                                );
205
206                                // Send rejection response
207                                let response = McpResponse::Registered { success: false };
208                                let _ = tx
209                                    .send(Message::Text(serde_json::to_string(&response).unwrap()));
210                                continue; // Skip registration
211                            }
212
213                            // Store connection
214                            let conn = McpConnection {
215                                tx: tx.clone(),
216                                project: project.clone(),
217                                connected_at: chrono::Utc::now(),
218                            };
219
220                            state_for_recv
221                                .mcp_connections
222                                .write()
223                                .await
224                                .insert(path.clone(), conn);
225                            project_path = Some(path.clone());
226
227                            // Update Registry immediately to set mcp_connected=true
228                            match crate::dashboard::registry::ProjectRegistry::load() {
229                                Ok(mut registry) => {
230                                    if let Err(e) = registry.register_mcp_connection(
231                                        &project_path_buf,
232                                        project.agent.clone(),
233                                    ) {
234                                        tracing::warn!(
235                                            "Failed to update Registry for MCP connection: {}",
236                                            e
237                                        );
238                                    } else {
239                                        tracing::info!(
240                                            "✓ Updated Registry: {} is now mcp_connected=true",
241                                            project.name
242                                        );
243                                    }
244                                },
245                                Err(e) => {
246                                    tracing::warn!("Failed to load Registry: {}", e);
247                                },
248                            }
249
250                            // Send confirmation
251                            let response = McpResponse::Registered { success: true };
252                            let _ =
253                                tx.send(Message::Text(serde_json::to_string(&response).unwrap()));
254
255                            // Broadcast to UI clients
256                            let ui_msg = UiMessage::ProjectOnline { project };
257                            state_for_recv
258                                .broadcast_to_ui(&serde_json::to_string(&ui_msg).unwrap())
259                                .await;
260                        },
261                        Ok(McpMessage::Ping) => {
262                            // Respond with pong
263                            let response = McpResponse::Pong;
264                            let _ =
265                                tx.send(Message::Text(serde_json::to_string(&response).unwrap()));
266                        },
267                        Err(e) => {
268                            tracing::warn!("Failed to parse MCP message: {}", e);
269                        },
270                    }
271                },
272                Message::Close(_) => {
273                    break;
274                },
275                _ => {},
276            }
277        }
278
279        project_path
280    });
281
282    // Wait for any task to finish
283    tokio::select! {
284        _ = (&mut send_task) => {
285            recv_task.abort();
286            heartbeat_task.abort();
287        }
288        project_path_result = (&mut recv_task) => {
289            send_task.abort();
290            heartbeat_task.abort();
291            if let Ok(Some(path)) = project_path_result {
292                // Clean up connection
293                state.mcp_connections.write().await.remove(&path);
294
295                // Update Registry immediately to set mcp_connected=false
296                let project_path_buf = std::path::PathBuf::from(&path);
297                match crate::dashboard::registry::ProjectRegistry::load() {
298                    Ok(mut registry) => {
299                        if let Err(e) = registry.unregister_mcp_connection(&project_path_buf) {
300                            tracing::warn!("Failed to update Registry for MCP disconnection: {}", e);
301                        } else {
302                            tracing::info!("✓ Updated Registry: {} is now mcp_connected=false", path);
303                        }
304                    }
305                    Err(e) => {
306                        tracing::warn!("Failed to load Registry: {}", e);
307                    }
308                }
309
310                // Notify UI clients
311                let ui_msg = UiMessage::ProjectOffline { project_path: path.clone() };
312                state
313                    .broadcast_to_ui(&serde_json::to_string(&ui_msg).unwrap())
314                    .await;
315
316                tracing::info!("MCP disconnected: {}", path);
317            }
318        }
319        _ = (&mut heartbeat_task) => {
320            send_task.abort();
321            recv_task.abort();
322        }
323    }
324}
325
326/// Handle UI WebSocket connections
327pub async fn handle_ui_websocket(
328    ws: WebSocketUpgrade,
329    State(state): State<WebSocketState>,
330) -> impl IntoResponse {
331    ws.on_upgrade(move |socket| handle_ui_socket(socket, state))
332}
333
334async fn handle_ui_socket(socket: WebSocket, state: WebSocketState) {
335    let (mut sender, mut receiver) = socket.split();
336    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
337
338    // Spawn task to forward messages from channel to WebSocket
339    let mut send_task = tokio::spawn(async move {
340        while let Some(msg) = rx.recv().await {
341            if sender.send(msg).await.is_err() {
342                break;
343            }
344        }
345    });
346
347    // Send initial project list
348    let projects = state.get_online_projects().await;
349    let init_msg = UiMessage::Init { projects };
350    let _ = tx.send(Message::Text(serde_json::to_string(&init_msg).unwrap()));
351
352    // Register this UI connection
353    let conn = UiConnection {
354        tx: tx.clone(),
355        connected_at: chrono::Utc::now(),
356    };
357    let conn_index = {
358        let mut connections = state.ui_connections.write().await;
359        connections.push(conn);
360        connections.len() - 1
361    };
362
363    tracing::info!("UI client connected");
364
365    // Clone tx for heartbeat task
366    let heartbeat_tx = tx.clone();
367
368    // Spawn heartbeat task - send ping every 30 seconds
369    let mut heartbeat_task = tokio::spawn(async move {
370        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
371        loop {
372            interval.tick().await;
373            let ping_msg = UiMessage::Ping;
374            if heartbeat_tx
375                .send(Message::Text(serde_json::to_string(&ping_msg).unwrap()))
376                .is_err()
377            {
378                // Connection closed
379                break;
380            }
381            tracing::trace!("Sent heartbeat ping to UI client");
382        }
383    });
384
385    // Handle incoming messages (mostly just keep-alive and pong)
386    let mut recv_task = tokio::spawn(async move {
387        while let Some(Ok(msg)) = receiver.next().await {
388            match msg {
389                Message::Text(text) => {
390                    // UI can send pong or other messages
391                    tracing::trace!("Received from UI: {}", text);
392                },
393                Message::Pong(_) => {
394                    tracing::trace!("Received pong from UI");
395                },
396                Message::Close(_) => {
397                    break;
398                },
399                _ => {},
400            }
401        }
402    });
403
404    // Wait for any task to finish
405    tokio::select! {
406        _ = (&mut send_task) => {
407            recv_task.abort();
408            heartbeat_task.abort();
409        }
410        _ = (&mut recv_task) => {
411            send_task.abort();
412            heartbeat_task.abort();
413        }
414        _ = (&mut heartbeat_task) => {
415            send_task.abort();
416            recv_task.abort();
417        }
418    }
419
420    // Clean up UI connection
421    state.ui_connections.write().await.swap_remove(conn_index);
422    tracing::info!("UI client disconnected");
423}