intent_engine/dashboard/
websocket.rs

1// WebSocket support for Dashboard
2// Handles real-time communication between MCP servers and UI clients
3
/// Version of the Intent-Engine WebSocket protocol spoken by this server.
///
/// Incoming messages are accepted only when the major component of their `version`
/// (the text before the first `.`) equals the major component of this value.
pub const PROTOCOL_VERSION: &str = "1.0";
6
7use axum::{
8    extract::{
9        ws::{Message, WebSocket},
10        State, WebSocketUpgrade,
11    },
12    response::IntoResponse,
13};
14use futures_util::{SinkExt, StreamExt};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19
20/// Protocol message wrapper - wraps all WebSocket messages with version and timestamp
21#[derive(Debug, Serialize, Deserialize)]
22pub struct ProtocolMessage<T> {
23    /// Protocol version (e.g., "1.0")
24    pub version: String,
25    /// Message type identifier
26    #[serde(rename = "type")]
27    pub message_type: String,
28    /// Message payload
29    pub payload: T,
30    /// ISO 8601 timestamp when message was created
31    pub timestamp: String,
32}
33
34impl<T> ProtocolMessage<T>
35where
36    T: Serialize,
37{
38    /// Create a new protocol message with current timestamp
39    pub fn new(message_type: impl Into<String>, payload: T) -> Self {
40        Self {
41            version: PROTOCOL_VERSION.to_string(),
42            message_type: message_type.into(),
43            payload,
44            timestamp: chrono::Utc::now().to_rfc3339(),
45        }
46    }
47
48    /// Serialize to JSON string
49    pub fn to_json(&self) -> Result<String, serde_json::Error> {
50        serde_json::to_string(self)
51    }
52}
53
54impl<T> ProtocolMessage<T>
55where
56    T: for<'de> Deserialize<'de>,
57{
58    /// Deserialize from JSON string with version validation
59    pub fn from_json(json: &str) -> Result<Self, String> {
60        let msg: Self = serde_json::from_str(json)
61            .map_err(|e| format!("Failed to parse protocol message: {}", e))?;
62
63        // Validate protocol version (major version must match)
64        let expected_major = PROTOCOL_VERSION.split('.').next().unwrap_or("1");
65        let received_major = msg.version.split('.').next().unwrap_or("0");
66
67        if expected_major != received_major {
68            return Err(format!(
69                "Protocol version mismatch: expected {}, got {}",
70                PROTOCOL_VERSION, msg.version
71            ));
72        }
73
74        Ok(msg)
75    }
76}
77
78/// Project information sent by MCP servers
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ProjectInfo {
81    pub path: String,
82    pub name: String,
83    pub db_path: String,
84    #[serde(skip_serializing_if = "Option::is_none")]
85    pub agent: Option<String>,
86    /// Whether this project has an active MCP connection
87    pub mcp_connected: bool,
88    /// Whether the Dashboard serving this project is online
89    pub is_online: bool,
90}
91
92/// MCP connection entry
93#[derive(Debug)]
94pub struct McpConnection {
95    pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
96    pub project: ProjectInfo,
97    pub connected_at: chrono::DateTime<chrono::Utc>,
98}
99
100/// UI connection entry
101#[derive(Debug)]
102pub struct UiConnection {
103    pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
104    pub connected_at: chrono::DateTime<chrono::Utc>,
105}
106
107/// Shared WebSocket state
108#[derive(Clone)]
109pub struct WebSocketState {
110    /// Project path → MCP connection
111    pub mcp_connections: Arc<RwLock<HashMap<String, McpConnection>>>,
112    /// List of active UI connections
113    pub ui_connections: Arc<RwLock<Vec<UiConnection>>>,
114}
115
116impl Default for WebSocketState {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
122impl WebSocketState {
123    pub fn new() -> Self {
124        Self {
125            mcp_connections: Arc::new(RwLock::new(HashMap::new())),
126            ui_connections: Arc::new(RwLock::new(Vec::new())),
127        }
128    }
129
130    /// Broadcast message to all UI connections
131    pub async fn broadcast_to_ui(&self, message: &str) {
132        let connections = self.ui_connections.read().await;
133        for conn in connections.iter() {
134            let _ = conn.tx.send(Message::Text(message.to_string()));
135        }
136    }
137
138    /// Get list of all online projects from in-memory state
139    pub async fn get_online_projects(&self) -> Vec<ProjectInfo> {
140        // Read from in-memory MCP connections
141        let connections = self.mcp_connections.read().await;
142
143        connections
144            .values()
145            .map(|conn| {
146                let mut project = conn.project.clone();
147                project.mcp_connected = true; // All projects in the map are connected
148                project
149            })
150            .collect()
151    }
152
153    /// Get list of all online projects
154    /// Always includes the current Dashboard project plus all MCP-connected projects
155    /// This is the single source of truth for project status
156    pub async fn get_online_projects_with_current(
157        &self,
158        current_project_name: &str,
159        current_project_path: &std::path::Path,
160        current_db_path: &std::path::Path,
161        _port: u16,
162    ) -> Vec<ProjectInfo> {
163        let connections = self.mcp_connections.read().await;
164        let current_path_str = current_project_path.display().to_string();
165
166        let mut projects = Vec::new();
167
168        // 1. Always add current Dashboard project first
169        // Check if this project also has an MCP connection
170        let current_has_mcp = connections
171            .values()
172            .any(|conn| conn.project.path == current_path_str);
173
174        projects.push(ProjectInfo {
175            name: current_project_name.to_string(),
176            path: current_path_str.clone(),
177            db_path: current_db_path.display().to_string(),
178            agent: None, // Dashboard itself doesn't have an agent name
179            mcp_connected: current_has_mcp,
180            is_online: true, // Dashboard is online (serving this response)
181        });
182
183        // 2. Add all other MCP-connected projects (excluding current project to avoid duplication)
184        for conn in connections.values() {
185            if conn.project.path != current_path_str {
186                let mut project = conn.project.clone();
187                project.mcp_connected = true;
188                project.is_online = true; // MCP connection means project is online
189                projects.push(project);
190            }
191        }
192
193        projects
194    }
195}
196
197// ============================================================================
198// Payload Structures (used inside ProtocolMessage)
199// ============================================================================
200
201/// Payload for MCP register message
202#[derive(Debug, Serialize, Deserialize)]
203pub struct RegisterPayload {
204    pub project: ProjectInfo,
205}
206
207/// Payload for MCP registered response
208#[derive(Debug, Serialize, Deserialize)]
209pub struct RegisteredPayload {
210    pub success: bool,
211}
212
213/// Empty payload for ping/pong messages
214#[derive(Debug, Serialize, Deserialize)]
215pub struct EmptyPayload {}
216
217/// Payload for UI init message
218#[derive(Debug, Serialize, Deserialize)]
219pub struct InitPayload {
220    pub projects: Vec<ProjectInfo>,
221}
222
223/// Payload for UI project_online message
224#[derive(Debug, Serialize, Deserialize)]
225pub struct ProjectOnlinePayload {
226    pub project: ProjectInfo,
227}
228
229/// Payload for UI project_offline message
230#[derive(Debug, Serialize, Deserialize)]
231pub struct ProjectOfflinePayload {
232    pub project_path: String,
233}
234
235/// Payload for hello message (client → server)
236#[derive(Debug, Serialize, Deserialize)]
237pub struct HelloPayload {
238    /// Client entity type ("mcp" or "ui")
239    pub entity_type: String,
240    /// Client capabilities (optional)
241    #[serde(skip_serializing_if = "Option::is_none")]
242    pub capabilities: Option<Vec<String>>,
243}
244
245/// Payload for welcome message (server → client)
246#[derive(Debug, Serialize, Deserialize)]
247pub struct WelcomePayload {
248    /// Server capabilities
249    pub capabilities: Vec<String>,
250    /// Session ID
251    pub session_id: String,
252}
253
254/// Payload for goodbye message
255#[derive(Debug, Serialize, Deserialize)]
256pub struct GoodbyePayload {
257    /// Reason for closing (optional)
258    #[serde(skip_serializing_if = "Option::is_none")]
259    pub reason: Option<String>,
260}
261
262/// Payload for error message (Protocol v1.0 Section 4.5)
263#[derive(Debug, Serialize, Deserialize)]
264pub struct ErrorPayload {
265    /// Machine-readable error code
266    pub code: String,
267    /// Human-readable error message
268    pub message: String,
269    /// Optional additional details (for debugging)
270    #[serde(skip_serializing_if = "Option::is_none")]
271    pub details: Option<serde_json::Value>,
272}
273
/// Standard error codes (Protocol v1.0 Section 4.5).
///
/// Each constant is sent verbatim in the `code` field of an `ErrorPayload`.
pub mod error_codes {
    /// The message could not be parsed as a protocol message.
    pub const INVALID_MESSAGE: &str = "invalid_message";
    /// The message's protocol major version does not match the server's.
    pub const UNSUPPORTED_VERSION: &str = "unsupported_version";
    /// A registration path was rejected (for example, it lies in the temporary directory).
    pub const INVALID_PATH: &str = "invalid_path";
    /// Standard code reserved for failed registrations.
    pub const REGISTRATION_FAILED: &str = "registration_failed";
    /// Standard code reserved for unexpected server-side failures.
    pub const INTERNAL_ERROR: &str = "internal_error";
}
282
283// ============================================================================
284// Helper Functions for Sending Protocol Messages
285// ============================================================================
286
287/// Send a protocol message through a channel
288fn send_protocol_message<T: Serialize>(
289    tx: &tokio::sync::mpsc::UnboundedSender<Message>,
290    message_type: &str,
291    payload: T,
292) -> Result<(), String> {
293    let protocol_msg = ProtocolMessage::new(message_type, payload);
294    let json = protocol_msg
295        .to_json()
296        .map_err(|e| format!("Failed to serialize message: {}", e))?;
297
298    tx.send(Message::Text(json))
299        .map_err(|_| "Failed to send message: channel closed".to_string())
300}
301
302/// Handle MCP WebSocket connections
303pub async fn handle_mcp_websocket(
304    ws: WebSocketUpgrade,
305    State(app_state): State<crate::dashboard::server::AppState>,
306) -> impl IntoResponse {
307    ws.on_upgrade(move |socket| handle_mcp_socket(socket, app_state.ws_state))
308}
309
310async fn handle_mcp_socket(socket: WebSocket, state: WebSocketState) {
311    let (mut sender, mut receiver) = socket.split();
312    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
313
314    // Spawn task to forward messages from channel to WebSocket
315    let mut send_task = tokio::spawn(async move {
316        while let Some(msg) = rx.recv().await {
317            if sender.send(msg).await.is_err() {
318                break;
319            }
320        }
321    });
322
323    // Variables to track this connection
324    let mut project_path: Option<String> = None;
325    let mut session_welcomed = false; // Track if welcome handshake completed
326
327    // Clone state for use inside recv_task
328    let state_for_recv = state.clone();
329
330    // Clone tx for heartbeat task
331    let heartbeat_tx = tx.clone();
332
333    // Spawn heartbeat task - send ping every 30 seconds (Protocol v1.0 Section 4.1.3)
334    let mut heartbeat_task = tokio::spawn(async move {
335        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
336        // Skip the first tick (which completes immediately)
337        interval.tick().await;
338
339        loop {
340            interval.tick().await;
341            // Send ping to request heartbeat from client
342            if send_protocol_message(&heartbeat_tx, "ping", EmptyPayload {}).is_err() {
343                // Connection closed
344                break;
345            }
346            tracing::trace!("Sent heartbeat ping to MCP client");
347        }
348    });
349
350    // Handle incoming messages
351    let mut recv_task = tokio::spawn(async move {
352        while let Some(Ok(msg)) = receiver.next().await {
353            match msg {
354                Message::Text(text) => {
355                    // Parse incoming protocol message
356                    let parsed_msg = match ProtocolMessage::<serde_json::Value>::from_json(&text) {
357                        Ok(msg) => msg,
358                        Err(e) => {
359                            tracing::warn!("Protocol error: {}", e);
360
361                            // Send error message to client
362                            let error_code = if e.contains("version mismatch") {
363                                error_codes::UNSUPPORTED_VERSION
364                            } else {
365                                error_codes::INVALID_MESSAGE
366                            };
367
368                            let error_payload = ErrorPayload {
369                                code: error_code.to_string(),
370                                message: e.to_string(),
371                                details: None,
372                            };
373
374                            let _ = send_protocol_message(&tx, "error", error_payload);
375                            continue;
376                        },
377                    };
378
379                    match parsed_msg.message_type.as_str() {
380                        "hello" => {
381                            // Parse hello payload
382                            let hello: HelloPayload =
383                                match serde_json::from_value(parsed_msg.payload.clone()) {
384                                    Ok(h) => h,
385                                    Err(e) => {
386                                        tracing::warn!("Failed to parse hello payload: {}", e);
387                                        continue;
388                                    },
389                                };
390
391                            tracing::info!("Received hello from {} client", hello.entity_type);
392
393                            // Generate session ID
394                            let session_id = format!(
395                                "{}-{}",
396                                hello.entity_type,
397                                chrono::Utc::now().timestamp_millis()
398                            );
399
400                            // Send welcome response
401                            let welcome_payload = WelcomePayload {
402                                session_id,
403                                capabilities: vec![], // TODO: Add actual capabilities
404                            };
405
406                            if send_protocol_message(&tx, "welcome", welcome_payload).is_ok() {
407                                session_welcomed = true;
408                                tracing::debug!("Sent welcome message");
409                            } else {
410                                tracing::error!("Failed to send welcome message");
411                            }
412                        },
413                        "register" => {
414                            // Check if handshake completed (backward compatibility: allow register without hello for now)
415                            if !session_welcomed {
416                                tracing::warn!(
417                                    "MCP client registered without hello handshake (legacy client detected)"
418                                );
419                            }
420
421                            // Parse register payload
422                            let project: ProjectInfo =
423                                match serde_json::from_value(parsed_msg.payload.clone()) {
424                                    Ok(p) => p,
425                                    Err(e) => {
426                                        tracing::warn!("Failed to parse register payload: {}", e);
427                                        continue;
428                                    },
429                                };
430                            tracing::info!("MCP registering project: {}", project.name);
431
432                            let path = project.path.clone();
433                            let project_path_buf = std::path::PathBuf::from(&path);
434
435                            // Validate project path - reject temporary directories (Defense Layer 5)
436                            // This prevents test environments from polluting the Dashboard registry
437                            let normalized_path = project_path_buf
438                                .canonicalize()
439                                .unwrap_or_else(|_| project_path_buf.clone());
440
441                            // IMPORTANT: Canonicalize temp_dir to match normalized_path format (fixes Windows UNC paths)
442                            let temp_dir = std::env::temp_dir()
443                                .canonicalize()
444                                .unwrap_or_else(|_| std::env::temp_dir());
445                            let is_temp_path = normalized_path.starts_with(&temp_dir);
446
447                            if is_temp_path {
448                                tracing::warn!(
449                                    "Rejecting MCP registration for temporary/invalid path: {}",
450                                    path
451                                );
452
453                                // Send error message
454                                let error_payload = ErrorPayload {
455                                    code: error_codes::INVALID_PATH.to_string(),
456                                    message: "Path is in temporary directory".to_string(),
457                                    details: Some(serde_json::json!({"path": path})),
458                                };
459                                let _ = send_protocol_message(&tx, "error", error_payload);
460
461                                // Send rejection response
462                                let _ = send_protocol_message(
463                                    &tx,
464                                    "registered",
465                                    RegisteredPayload { success: false },
466                                );
467                                continue; // Skip registration
468                            }
469
470                            // Store connection
471                            let conn = McpConnection {
472                                tx: tx.clone(),
473                                project: project.clone(),
474                                connected_at: chrono::Utc::now(),
475                            };
476
477                            state_for_recv
478                                .mcp_connections
479                                .write()
480                                .await
481                                .insert(path.clone(), conn);
482                            project_path = Some(path.clone());
483
484                            tracing::info!("✓ MCP connected: {} ({})", project.name, path);
485
486                            // Send confirmation
487                            let _ = send_protocol_message(
488                                &tx,
489                                "registered",
490                                RegisteredPayload { success: true },
491                            );
492
493                            // Broadcast to UI clients with mcp_connected=true
494                            let mut project_info = project.clone();
495                            project_info.mcp_connected = true;
496                            let ui_msg = ProtocolMessage::new(
497                                "project_online",
498                                ProjectOnlinePayload {
499                                    project: project_info,
500                                },
501                            );
502                            state_for_recv
503                                .broadcast_to_ui(&ui_msg.to_json().unwrap())
504                                .await;
505                        },
506                        "pong" => {
507                            // Client responded to our ping - heartbeat confirmed
508                            tracing::trace!("Received pong from MCP client - heartbeat confirmed");
509                        },
510                        "goodbye" => {
511                            // Client is closing connection gracefully
512                            if let Ok(goodbye_payload) =
513                                serde_json::from_value::<GoodbyePayload>(parsed_msg.payload)
514                            {
515                                if let Some(reason) = goodbye_payload.reason {
516                                    tracing::info!("MCP client closing connection: {}", reason);
517                                } else {
518                                    tracing::info!("MCP client closing connection gracefully");
519                                }
520                            }
521                            // Break loop to close connection
522                            break;
523                        },
524                        _ => {
525                            tracing::warn!("Unknown message type: {}", parsed_msg.message_type);
526                        },
527                    }
528                },
529                Message::Close(_) => {
530                    tracing::info!("MCP client closed WebSocket");
531                    break;
532                },
533                _ => {},
534            }
535        }
536
537        project_path
538    });
539
540    // Wait for any task to finish
541    tokio::select! {
542        _ = (&mut send_task) => {
543            recv_task.abort();
544            heartbeat_task.abort();
545        }
546        project_path_result = (&mut recv_task) => {
547            send_task.abort();
548            heartbeat_task.abort();
549            if let Ok(Some(path)) = project_path_result {
550                // Clean up connection
551                state.mcp_connections.write().await.remove(&path);
552
553                tracing::info!("MCP disconnected: {}", path);
554
555                // Notify UI clients
556                let ui_msg = ProtocolMessage::new(
557                    "project_offline",
558                    ProjectOfflinePayload { project_path: path.clone() },
559                );
560                state
561                    .broadcast_to_ui(&ui_msg.to_json().unwrap())
562                    .await;
563
564                tracing::info!("MCP disconnected: {}", path);
565            }
566        }
567        _ = (&mut heartbeat_task) => {
568            send_task.abort();
569            recv_task.abort();
570        }
571    }
572}
573
574/// Handle UI WebSocket connections
575pub async fn handle_ui_websocket(
576    ws: WebSocketUpgrade,
577    State(app_state): State<crate::dashboard::server::AppState>,
578) -> impl IntoResponse {
579    ws.on_upgrade(move |socket| handle_ui_socket(socket, app_state))
580}
581
582async fn handle_ui_socket(socket: WebSocket, app_state: crate::dashboard::server::AppState) {
583    let (mut sender, mut receiver) = socket.split();
584    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
585
586    // Spawn task to forward messages from channel to WebSocket
587    let mut send_task = tokio::spawn(async move {
588        while let Some(msg) = rx.recv().await {
589            if sender.send(msg).await.is_err() {
590                break;
591            }
592        }
593    });
594
595    // Protocol v1.0 Compliance: Wait for client to send "hello" first
596    // The "init" message will be sent after receiving "hello" and sending "welcome"
597    // This is handled in the message loop below
598
599    // Register this UI connection
600    let conn = UiConnection {
601        tx: tx.clone(),
602        connected_at: chrono::Utc::now(),
603    };
604    let conn_index = {
605        let mut connections = app_state.ws_state.ui_connections.write().await;
606        connections.push(conn);
607        connections.len() - 1
608    };
609
610    tracing::info!("UI client connected");
611
612    // Clone app_state for use inside recv_task
613    let app_state_for_recv = app_state.clone();
614
615    // Clone tx for heartbeat task
616    let heartbeat_tx = tx.clone();
617
618    // Spawn heartbeat task - send ping every 30 seconds (Protocol v1.0 Section 4.1.3)
619    let mut heartbeat_task = tokio::spawn(async move {
620        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
621        // Skip the first tick (which completes immediately)
622        interval.tick().await;
623
624        loop {
625            interval.tick().await;
626            if send_protocol_message(&heartbeat_tx, "ping", EmptyPayload {}).is_err() {
627                // Connection closed
628                break;
629            }
630            tracing::trace!("Sent heartbeat ping to UI client");
631        }
632    });
633
634    // Handle incoming messages (mostly just keep-alive and pong)
635    let mut recv_task = tokio::spawn(async move {
636        while let Some(Ok(msg)) = receiver.next().await {
637            match msg {
638                Message::Text(text) => {
639                    // Parse protocol message from UI
640                    if let Ok(parsed_msg) =
641                        serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
642                    {
643                        match parsed_msg.message_type.as_str() {
644                            "hello" => {
645                                // Parse hello payload
646                                if let Ok(hello) =
647                                    serde_json::from_value::<HelloPayload>(parsed_msg.payload)
648                                {
649                                    tracing::info!(
650                                        "Received hello from {} client",
651                                        hello.entity_type
652                                    );
653
654                                    // Generate session ID
655                                    let session_id = format!(
656                                        "{}-{}",
657                                        hello.entity_type,
658                                        chrono::Utc::now().timestamp_millis()
659                                    );
660
661                                    // Send welcome response
662                                    let welcome_payload = WelcomePayload {
663                                        session_id,
664                                        capabilities: vec![],
665                                    };
666
667                                    let _ = send_protocol_message(&tx, "welcome", welcome_payload);
668                                    tracing::debug!("Sent welcome message to UI");
669
670                                    // Send init after welcome (protocol-compliant flow)
671                                    // Re-fetch projects in case state changed
672                                    let current_projects = {
673                                        let current_project =
674                                            app_state_for_recv.current_project.read().await;
675                                        let port = app_state_for_recv.port;
676                                        app_state_for_recv
677                                            .ws_state
678                                            .get_online_projects_with_current(
679                                                &current_project.project_name,
680                                                &current_project.project_path,
681                                                &current_project.db_path,
682                                                port,
683                                            )
684                                            .await
685                                    };
686                                    let _ = send_protocol_message(
687                                        &tx,
688                                        "init",
689                                        InitPayload {
690                                            projects: current_projects,
691                                        },
692                                    );
693                                }
694                            },
695                            "pong" => {
696                                tracing::trace!("Received pong from UI");
697                            },
698                            "goodbye" => {
699                                // UI client closing gracefully
700                                if let Ok(goodbye_payload) =
701                                    serde_json::from_value::<GoodbyePayload>(parsed_msg.payload)
702                                {
703                                    if let Some(reason) = goodbye_payload.reason {
704                                        tracing::info!("UI client closing: {}", reason);
705                                    } else {
706                                        tracing::info!("UI client closing gracefully");
707                                    }
708                                }
709                                break;
710                            },
711                            _ => {
712                                tracing::trace!(
713                                    "Received from UI: {} ({})",
714                                    parsed_msg.message_type,
715                                    text
716                                );
717                            },
718                        }
719                    } else {
720                        tracing::trace!("Received non-protocol message from UI: {}", text);
721                    }
722                },
723                Message::Pong(_) => {
724                    tracing::trace!("Received WebSocket pong from UI");
725                },
726                Message::Close(_) => {
727                    tracing::info!("UI client closed WebSocket");
728                    break;
729                },
730                _ => {},
731            }
732        }
733    });
734
735    // Wait for any task to finish
736    tokio::select! {
737        _ = (&mut send_task) => {
738            recv_task.abort();
739            heartbeat_task.abort();
740        }
741        _ = (&mut recv_task) => {
742            send_task.abort();
743            heartbeat_task.abort();
744        }
745        _ = (&mut heartbeat_task) => {
746            send_task.abort();
747            recv_task.abort();
748        }
749    }
750
751    // Clean up UI connection
752    app_state
753        .ws_state
754        .ui_connections
755        .write()
756        .await
757        .swap_remove(conn_index);
758    tracing::info!("UI client disconnected");
759}