intent_engine/dashboard/websocket.rs

1// WebSocket support for Dashboard
2// Handles real-time communication between MCP servers and UI clients
3
4use axum::{
5    extract::{
6        ws::{Message, WebSocket},
7        State, WebSocketUpgrade,
8    },
9    response::IntoResponse,
10};
11use futures_util::{SinkExt, StreamExt};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::RwLock;
16
17/// Project information sent by MCP servers
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ProjectInfo {
20    pub path: String,
21    pub name: String,
22    pub db_path: String,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub agent: Option<String>,
25}
26
27/// MCP connection entry
28#[derive(Debug)]
29pub struct McpConnection {
30    pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
31    pub project: ProjectInfo,
32    pub connected_at: chrono::DateTime<chrono::Utc>,
33}
34
35/// UI connection entry
36#[derive(Debug)]
37pub struct UiConnection {
38    pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
39    pub connected_at: chrono::DateTime<chrono::Utc>,
40}
41
42/// Shared WebSocket state
43#[derive(Clone)]
44pub struct WebSocketState {
45    /// Project path → MCP connection
46    pub mcp_connections: Arc<RwLock<HashMap<String, McpConnection>>>,
47    /// List of active UI connections
48    pub ui_connections: Arc<RwLock<Vec<UiConnection>>>,
49}
50
51impl Default for WebSocketState {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl WebSocketState {
58    pub fn new() -> Self {
59        Self {
60            mcp_connections: Arc::new(RwLock::new(HashMap::new())),
61            ui_connections: Arc::new(RwLock::new(Vec::new())),
62        }
63    }
64
65    /// Broadcast message to all UI connections
66    pub async fn broadcast_to_ui(&self, message: &str) {
67        let connections = self.ui_connections.read().await;
68        for conn in connections.iter() {
69            let _ = conn.tx.send(Message::Text(message.to_string()));
70        }
71    }
72
73    /// Get list of currently connected projects from Registry
74    pub async fn get_online_projects(&self) -> Vec<ProjectInfo> {
75        // Load from Registry to get accurate mcp_connected status
76        // This ensures UI gets complete project list even if WebSocket connections haven't been established yet
77        match crate::dashboard::registry::ProjectRegistry::load() {
78            Ok(registry) => registry
79                .projects
80                .iter()
81                .filter(|p| p.mcp_connected)
82                .map(|p| ProjectInfo {
83                    name: p.name.clone(),
84                    path: p.path.display().to_string(),
85                    db_path: p.db_path.display().to_string(),
86                    agent: p.mcp_agent.clone(),
87                })
88                .collect(),
89            Err(e) => {
90                tracing::warn!("Failed to load registry for online projects: {}", e);
91                Vec::new()
92            },
93        }
94    }
95}
96
97/// Message types from MCP to Dashboard
98#[derive(Debug, Deserialize)]
99#[serde(tag = "type")]
100enum McpMessage {
101    #[serde(rename = "register")]
102    Register { project: ProjectInfo },
103    #[serde(rename = "ping")]
104    Ping,
105}
106
107/// Message types from Dashboard to MCP
108#[derive(Debug, Serialize)]
109#[serde(tag = "type")]
110enum McpResponse {
111    #[serde(rename = "registered")]
112    Registered { success: bool },
113    #[serde(rename = "pong")]
114    Pong,
115}
116
117/// Message types from Dashboard to UI
118#[derive(Debug, Serialize)]
119#[serde(tag = "type")]
120enum UiMessage {
121    #[serde(rename = "init")]
122    Init { projects: Vec<ProjectInfo> },
123    #[serde(rename = "project_online")]
124    ProjectOnline { project: ProjectInfo },
125    #[serde(rename = "project_offline")]
126    ProjectOffline { project_path: String },
127    #[serde(rename = "ping")]
128    Ping,
129}
130
131/// Handle MCP WebSocket connections
132pub async fn handle_mcp_websocket(
133    ws: WebSocketUpgrade,
134    State(state): State<WebSocketState>,
135) -> impl IntoResponse {
136    ws.on_upgrade(move |socket| handle_mcp_socket(socket, state))
137}
138
139async fn handle_mcp_socket(socket: WebSocket, state: WebSocketState) {
140    let (mut sender, mut receiver) = socket.split();
141    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
142
143    // Spawn task to forward messages from channel to WebSocket
144    let mut send_task = tokio::spawn(async move {
145        while let Some(msg) = rx.recv().await {
146            if sender.send(msg).await.is_err() {
147                break;
148            }
149        }
150    });
151
152    // Variables to track this connection
153    let mut project_path: Option<String> = None;
154
155    // Clone state for use inside recv_task
156    let state_for_recv = state.clone();
157
158    // Clone tx for heartbeat task
159    let heartbeat_tx = tx.clone();
160
161    // Spawn heartbeat task - send ping every 30 seconds
162    let mut heartbeat_task = tokio::spawn(async move {
163        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
164        loop {
165            interval.tick().await;
166            let ping_msg = McpResponse::Pong; // Use Pong as keepalive for MCP
167            if heartbeat_tx
168                .send(Message::Text(serde_json::to_string(&ping_msg).unwrap()))
169                .is_err()
170            {
171                // Connection closed
172                break;
173            }
174            tracing::trace!("Sent heartbeat to MCP client");
175        }
176    });
177
178    // Handle incoming messages
179    let mut recv_task = tokio::spawn(async move {
180        while let Some(Ok(msg)) = receiver.next().await {
181            match msg {
182                Message::Text(text) => {
183                    // Parse incoming message
184                    match serde_json::from_str::<McpMessage>(&text) {
185                        Ok(McpMessage::Register { project }) => {
186                            tracing::info!("MCP registering project: {}", project.name);
187
188                            let path = project.path.clone();
189                            let project_path_buf = std::path::PathBuf::from(&path);
190
191                            // Validate project path - reject temporary directories (Defense Layer 5)
192                            // This prevents test environments from polluting the Dashboard registry
193                            let normalized_path = project_path_buf
194                                .canonicalize()
195                                .unwrap_or_else(|_| project_path_buf.clone());
196
197                            // IMPORTANT: Canonicalize temp_dir to match normalized_path format (fixes Windows UNC paths)
198                            let temp_dir = std::env::temp_dir()
199                                .canonicalize()
200                                .unwrap_or_else(|_| std::env::temp_dir());
201                            let is_temp_path = normalized_path.starts_with(&temp_dir);
202
203                            if is_temp_path {
204                                tracing::warn!(
205                                    "Rejecting MCP registration for temporary/invalid path: {}",
206                                    path
207                                );
208
209                                // Send rejection response
210                                let response = McpResponse::Registered { success: false };
211                                let _ = tx
212                                    .send(Message::Text(serde_json::to_string(&response).unwrap()));
213                                continue; // Skip registration
214                            }
215
216                            // Store connection
217                            let conn = McpConnection {
218                                tx: tx.clone(),
219                                project: project.clone(),
220                                connected_at: chrono::Utc::now(),
221                            };
222
223                            state_for_recv
224                                .mcp_connections
225                                .write()
226                                .await
227                                .insert(path.clone(), conn);
228                            project_path = Some(path.clone());
229
230                            // Update Registry immediately to set mcp_connected=true
231                            match crate::dashboard::registry::ProjectRegistry::load() {
232                                Ok(mut registry) => {
233                                    if let Err(e) = registry.register_mcp_connection(
234                                        &project_path_buf,
235                                        project.agent.clone(),
236                                    ) {
237                                        tracing::warn!(
238                                            "Failed to update Registry for MCP connection: {}",
239                                            e
240                                        );
241                                    } else {
242                                        tracing::info!(
243                                            "✓ Updated Registry: {} is now mcp_connected=true",
244                                            project.name
245                                        );
246                                    }
247                                },
248                                Err(e) => {
249                                    tracing::warn!("Failed to load Registry: {}", e);
250                                },
251                            }
252
253                            // Send confirmation
254                            let response = McpResponse::Registered { success: true };
255                            let _ =
256                                tx.send(Message::Text(serde_json::to_string(&response).unwrap()));
257
258                            // Broadcast to UI clients
259                            let ui_msg = UiMessage::ProjectOnline { project };
260                            state_for_recv
261                                .broadcast_to_ui(&serde_json::to_string(&ui_msg).unwrap())
262                                .await;
263                        },
264                        Ok(McpMessage::Ping) => {
265                            // Respond with pong
266                            let response = McpResponse::Pong;
267                            let _ =
268                                tx.send(Message::Text(serde_json::to_string(&response).unwrap()));
269                        },
270                        Err(e) => {
271                            tracing::warn!("Failed to parse MCP message: {}", e);
272                        },
273                    }
274                },
275                Message::Close(_) => {
276                    break;
277                },
278                _ => {},
279            }
280        }
281
282        project_path
283    });
284
285    // Wait for any task to finish
286    tokio::select! {
287        _ = (&mut send_task) => {
288            recv_task.abort();
289            heartbeat_task.abort();
290        }
291        project_path_result = (&mut recv_task) => {
292            send_task.abort();
293            heartbeat_task.abort();
294            if let Ok(Some(path)) = project_path_result {
295                // Clean up connection
296                state.mcp_connections.write().await.remove(&path);
297
298                // Update Registry immediately to set mcp_connected=false
299                let project_path_buf = std::path::PathBuf::from(&path);
300                match crate::dashboard::registry::ProjectRegistry::load() {
301                    Ok(mut registry) => {
302                        if let Err(e) = registry.unregister_mcp_connection(&project_path_buf) {
303                            tracing::warn!("Failed to update Registry for MCP disconnection: {}", e);
304                        } else {
305                            tracing::info!("✓ Updated Registry: {} is now mcp_connected=false", path);
306                        }
307                    }
308                    Err(e) => {
309                        tracing::warn!("Failed to load Registry: {}", e);
310                    }
311                }
312
313                // Notify UI clients
314                let ui_msg = UiMessage::ProjectOffline { project_path: path.clone() };
315                state
316                    .broadcast_to_ui(&serde_json::to_string(&ui_msg).unwrap())
317                    .await;
318
319                tracing::info!("MCP disconnected: {}", path);
320            }
321        }
322        _ = (&mut heartbeat_task) => {
323            send_task.abort();
324            recv_task.abort();
325        }
326    }
327}
328
329/// Handle UI WebSocket connections
330pub async fn handle_ui_websocket(
331    ws: WebSocketUpgrade,
332    State(state): State<WebSocketState>,
333) -> impl IntoResponse {
334    ws.on_upgrade(move |socket| handle_ui_socket(socket, state))
335}
336
337async fn handle_ui_socket(socket: WebSocket, state: WebSocketState) {
338    let (mut sender, mut receiver) = socket.split();
339    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
340
341    // Spawn task to forward messages from channel to WebSocket
342    let mut send_task = tokio::spawn(async move {
343        while let Some(msg) = rx.recv().await {
344            if sender.send(msg).await.is_err() {
345                break;
346            }
347        }
348    });
349
350    // Send initial project list
351    let projects = state.get_online_projects().await;
352    let init_msg = UiMessage::Init { projects };
353    let _ = tx.send(Message::Text(serde_json::to_string(&init_msg).unwrap()));
354
355    // Register this UI connection
356    let conn = UiConnection {
357        tx: tx.clone(),
358        connected_at: chrono::Utc::now(),
359    };
360    let conn_index = {
361        let mut connections = state.ui_connections.write().await;
362        connections.push(conn);
363        connections.len() - 1
364    };
365
366    tracing::info!("UI client connected");
367
368    // Clone tx for heartbeat task
369    let heartbeat_tx = tx.clone();
370
371    // Spawn heartbeat task - send ping every 30 seconds
372    let mut heartbeat_task = tokio::spawn(async move {
373        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
374        loop {
375            interval.tick().await;
376            let ping_msg = UiMessage::Ping;
377            if heartbeat_tx
378                .send(Message::Text(serde_json::to_string(&ping_msg).unwrap()))
379                .is_err()
380            {
381                // Connection closed
382                break;
383            }
384            tracing::trace!("Sent heartbeat ping to UI client");
385        }
386    });
387
388    // Handle incoming messages (mostly just keep-alive and pong)
389    let mut recv_task = tokio::spawn(async move {
390        while let Some(Ok(msg)) = receiver.next().await {
391            match msg {
392                Message::Text(text) => {
393                    // UI can send pong or other messages
394                    tracing::trace!("Received from UI: {}", text);
395                },
396                Message::Pong(_) => {
397                    tracing::trace!("Received pong from UI");
398                },
399                Message::Close(_) => {
400                    break;
401                },
402                _ => {},
403            }
404        }
405    });
406
407    // Wait for any task to finish
408    tokio::select! {
409        _ = (&mut send_task) => {
410            recv_task.abort();
411            heartbeat_task.abort();
412        }
413        _ = (&mut recv_task) => {
414            send_task.abort();
415            heartbeat_task.abort();
416        }
417        _ = (&mut heartbeat_task) => {
418            send_task.abort();
419            recv_task.abort();
420        }
421    }
422
423    // Clean up UI connection
424    state.ui_connections.write().await.swap_remove(conn_index);
425    tracing::info!("UI client disconnected");
426}