intent_engine/mcp/
ws_client.rs

//! WebSocket client for MCP → Dashboard communication.
//!
//! Registers MCP server instances with the Dashboard, answers the Dashboard's
//! keep-alive pings, forwards database-operation notifications to the
//! Dashboard, and reconnects automatically with exponential backoff.
//!
//! # Testing Strategy
//!
//! This module requires a running Dashboard server for testing.
//! Tests are located in `tests/ws_client_integration_test.rs` and marked with `#[ignore]`.
//!
//! To run the integration tests:
//! ```bash
//! # Start the Dashboard first
//! ie dashboard start
//!
//! # Run the integration tests
//! cargo test --test ws_client_integration_test -- --ignored
//! ```
18
19use anyhow::{Context, Result};
20use futures_util::{SinkExt, StreamExt};
21use serde::{Deserialize, Serialize};
22use std::path::PathBuf;
23use std::sync::Arc;
24use std::time::Duration;
25use tokio_tungstenite::{connect_async, tungstenite::Message};
26
27/// Protocol version
28const PROTOCOL_VERSION: &str = "1.0";
29
30/// Protocol message wrapper
31#[derive(Debug, Serialize, Deserialize)]
32struct ProtocolMessage<T> {
33    version: String,
34    #[serde(rename = "type")]
35    message_type: String,
36    payload: T,
37    timestamp: String,
38}
39
40impl<T: Serialize> ProtocolMessage<T> {
41    fn new(message_type: impl Into<String>, payload: T) -> Self {
42        Self {
43            version: PROTOCOL_VERSION.to_string(),
44            message_type: message_type.into(),
45            payload,
46            timestamp: chrono::Utc::now().to_rfc3339(),
47        }
48    }
49
50    fn to_json(&self) -> Result<String> {
51        serde_json::to_string(self).map_err(Into::into)
52    }
53}
54
55/// Empty payload for ping/pong messages
56#[derive(Debug, Serialize, Deserialize)]
57struct EmptyPayload {}
58
59/// Payload for registered response
60#[derive(Debug, Serialize, Deserialize)]
61struct RegisteredPayload {
62    success: bool,
63}
64
65/// Payload for goodbye message
66#[derive(Debug, Serialize, Deserialize)]
67struct GoodbyePayload {
68    #[serde(skip_serializing_if = "Option::is_none")]
69    reason: Option<String>,
70}
71
72/// Payload for hello message (client → server)
73#[derive(Debug, Serialize, Deserialize)]
74struct HelloPayload {
75    entity_type: String,
76    #[serde(skip_serializing_if = "Option::is_none")]
77    capabilities: Option<Vec<String>>,
78}
79
80/// Payload for welcome message (server → client)
81#[derive(Debug, Serialize, Deserialize)]
82struct WelcomePayload {
83    session_id: String,
84    #[serde(skip_serializing_if = "Option::is_none")]
85    capabilities: Option<Vec<String>>,
86}
87
88/// Payload for error message (server → client)
89#[derive(Debug, Serialize, Deserialize)]
90struct ErrorPayload {
91    code: String,
92    message: String,
93    #[serde(skip_serializing_if = "Option::is_none")]
94    details: Option<serde_json::Value>,
95}
96
97/// Project information sent to Dashboard
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct ProjectInfo {
100    pub path: String,
101    pub name: String,
102    pub db_path: String,
103    #[serde(skip_serializing_if = "Option::is_none")]
104    pub agent: Option<String>,
105    /// Whether this project has an active MCP connection
106    pub mcp_connected: bool,
107    /// Whether the Dashboard serving this project is online
108    pub is_online: bool,
109}
110
111/// Reconnection delays in seconds (exponential backoff with max)
112const RECONNECT_DELAYS: &[u64] = &[1, 2, 4, 8, 16, 32];
113
114/// Start WebSocket client with automatic reconnection
115/// This function runs indefinitely, reconnecting on disconnection
116pub async fn connect_to_dashboard(
117    project_path: PathBuf,
118    db_path: PathBuf,
119    agent: Option<String>,
120    notification_rx: Option<tokio::sync::mpsc::UnboundedReceiver<String>>,
121    dashboard_port: Option<u16>,
122) -> Result<()> {
123    // Validate project path once at the beginning
124    let normalized_project_path = project_path
125        .canonicalize()
126        .unwrap_or_else(|_| project_path.clone());
127
128    let temp_dir = std::env::temp_dir()
129        .canonicalize()
130        .unwrap_or_else(|_| std::env::temp_dir());
131
132    if normalized_project_path.starts_with(&temp_dir) {
133        tracing::warn!(
134            "Skipping Dashboard registration for temporary path: {}",
135            normalized_project_path.display()
136        );
137        return Ok(()); // Silently skip for temp paths
138    }
139
140    let mut attempt = 0;
141
142    // Convert notification_rx to Option<Arc<Mutex<>>> for sharing across reconnections
143    let notification_rx = notification_rx.map(|rx| Arc::new(tokio::sync::Mutex::new(rx)));
144
145    // Infinite reconnection loop
146    loop {
147        tracing::info!("Connecting to Dashboard (attempt {})...", attempt + 1);
148
149        match connect_and_run(
150            project_path.clone(),
151            db_path.clone(),
152            agent.clone(),
153            notification_rx.clone(),
154            dashboard_port,
155        )
156        .await
157        {
158            Ok(()) => {
159                // Graceful close - reset attempt counter and retry immediately
160                tracing::info!("Dashboard connection closed gracefully, reconnecting...");
161                attempt = 0;
162                // Small delay before reconnecting
163                tokio::time::sleep(Duration::from_secs(1)).await;
164            },
165            Err(e) => {
166                // Connection error - use exponential backoff
167                tracing::warn!("Dashboard connection failed: {}. Retrying...", e);
168
169                // Calculate delay with exponential backoff
170                let delay_index = std::cmp::min(attempt, RECONNECT_DELAYS.len() - 1);
171                let base_delay = RECONNECT_DELAYS[delay_index];
172
173                // Add jitter: ±25% random variance
174                let jitter_factor = rand::random::<f64>() * 2.0 - 1.0; // Range: -1.0 to 1.0
175                let jitter_ms = (base_delay * 1000) as f64 * 0.25 * jitter_factor;
176                let delay_ms = (base_delay * 1000) as f64 + jitter_ms;
177                let delay = Duration::from_millis(delay_ms.max(0.0) as u64);
178
179                tracing::info!(
180                    "Waiting {:.1}s before retry (base: {}s + jitter: {:.1}s)",
181                    delay.as_secs_f64(),
182                    base_delay,
183                    jitter_ms / 1000.0
184                );
185
186                tokio::time::sleep(delay).await;
187                attempt += 1;
188            },
189        }
190    }
191}
192
193/// Internal function: Connect to Dashboard and run until disconnection
194/// Returns Ok(()) on graceful close, Err on connection failure
195async fn connect_and_run(
196    project_path: PathBuf,
197    db_path: PathBuf,
198    agent: Option<String>,
199    notification_rx: Option<Arc<tokio::sync::Mutex<tokio::sync::mpsc::UnboundedReceiver<String>>>>,
200    dashboard_port: Option<u16>,
201) -> Result<()> {
202    // Extract project name from path
203    let project_name = project_path
204        .file_name()
205        .and_then(|n| n.to_str())
206        .unwrap_or("unknown")
207        .to_string();
208
209    // Normalize paths to handle symlinks
210    let normalized_project_path = project_path
211        .canonicalize()
212        .unwrap_or_else(|_| project_path.clone());
213    let normalized_db_path = db_path.canonicalize().unwrap_or_else(|_| db_path.clone());
214
215    // Create project info
216    let project_info = ProjectInfo {
217        path: normalized_project_path.to_string_lossy().to_string(),
218        name: project_name,
219        db_path: normalized_db_path.to_string_lossy().to_string(),
220        agent,
221        mcp_connected: true,
222        is_online: true,
223    };
224
225    // Connect to Dashboard WebSocket
226    let port = dashboard_port.unwrap_or(11391);
227    let url = format!("ws://127.0.0.1:{}/ws/mcp", port);
228    let (ws_stream, _) = connect_async(&url)
229        .await
230        .context("Failed to connect to Dashboard WebSocket")?;
231
232    tracing::debug!("Connected to Dashboard at {}", url);
233
234    let (mut write, mut read) = ws_stream.split();
235
236    // Step 1: Send hello message (Protocol v1.0 handshake)
237    let hello_msg = ProtocolMessage::new(
238        "hello",
239        HelloPayload {
240            entity_type: "mcp_server".to_string(),
241            capabilities: Some(vec![]),
242        },
243    );
244    write
245        .send(Message::Text(hello_msg.to_json()?))
246        .await
247        .context("Failed to send hello message")?;
248    tracing::debug!("Sent hello message");
249
250    // Step 2: Wait for welcome response
251    if let Some(Ok(Message::Text(text))) = read.next().await {
252        match serde_json::from_str::<ProtocolMessage<WelcomePayload>>(&text) {
253            Ok(msg) if msg.message_type == "welcome" => {
254                tracing::debug!(
255                    "Received welcome from Dashboard (session: {})",
256                    msg.payload.session_id
257                );
258            },
259            Ok(msg) => {
260                tracing::warn!(
261                    "Expected welcome, received: {} (legacy Dashboard?)",
262                    msg.message_type
263                );
264                // Continue anyway for backward compatibility
265            },
266            Err(e) => {
267                tracing::warn!("Failed to parse welcome message: {}", e);
268            },
269        }
270    }
271
272    // Step 3: Send registration message
273    let register_msg = ProtocolMessage::new("register", project_info.clone());
274    let register_json = register_msg.to_json()?;
275    write
276        .send(Message::Text(register_json))
277        .await
278        .context("Failed to send register message")?;
279
280    // Step 4: Wait for registration confirmation
281    if let Some(Ok(Message::Text(text))) = read.next().await {
282        match serde_json::from_str::<ProtocolMessage<RegisteredPayload>>(&text) {
283            Ok(msg) if msg.message_type == "registered" && msg.payload.success => {
284                tracing::debug!("Successfully registered with Dashboard");
285            },
286            Ok(msg) if msg.message_type == "registered" && !msg.payload.success => {
287                anyhow::bail!("Dashboard rejected registration");
288            },
289            _ => {
290                tracing::debug!("Unexpected response during registration: {}", text);
291            },
292        }
293    }
294
295    // Spawn read/write task to handle messages and respond to pings
296    // Protocol v1.0 Section 4.1.3: Dashboard sends ping, client responds with pong
297    tokio::spawn(async move {
298        loop {
299            // Handle notification channel if available
300            if let Some(ref rx) = notification_rx {
301                let mut rx_guard = rx.lock().await;
302                tokio::select! {
303                    msg_result = read.next() => {
304                        if let Some(Ok(msg)) = msg_result {
305                        match msg {
306                            Message::Text(text) => {
307                                if let Ok(msg) =
308                                    serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
309                                {
310                                    match msg.message_type.as_str() {
311                                        "ping" => {
312                                            // Dashboard sent ping - respond with pong
313                                            tracing::debug!(
314                                                "Received ping from Dashboard, responding with pong"
315                                            );
316                                            let pong_msg = ProtocolMessage::new("pong", EmptyPayload {});
317                                            if let Ok(pong_json) = pong_msg.to_json() {
318                                                if write.send(Message::Text(pong_json)).await.is_err() {
319                                                    tracing::warn!(
320                                                        "Failed to send pong - Dashboard connection lost"
321                                                    );
322                                                    break;
323                                                }
324                                            }
325                                        },
326                                        "error" => {
327                                            // Dashboard sent an error
328                                            if let Ok(error) =
329                                                serde_json::from_value::<ErrorPayload>(msg.payload)
330                                            {
331                                                tracing::error!(
332                                                    "Dashboard error [{}]: {}",
333                                                    error.code,
334                                                    error.message
335                                                );
336                                                if let Some(details) = error.details {
337                                                    tracing::error!("  Details: {}", details);
338                                                }
339
340                                                // Handle critical errors
341                                                match error.code.as_str() {
342                                                    "unsupported_version" => {
343                                                        tracing::error!(
344                                                            "Protocol version mismatch - connection will close"
345                                                        );
346                                                        break;
347                                                    },
348                                                    "invalid_path" => {
349                                                        tracing::error!("Project path rejected by Dashboard");
350                                                        break;
351                                                    },
352                                                    _ => {
353                                                        // Non-critical errors, continue
354                                                    },
355                                                }
356                                            }
357                                        },
358                                        "goodbye" => {
359                                            // Dashboard is closing connection gracefully
360                                            if let Ok(goodbye) =
361                                                serde_json::from_value::<GoodbyePayload>(msg.payload)
362                                            {
363                                                if let Some(reason) = goodbye.reason {
364                                                    tracing::info!("Dashboard closing connection: {}", reason);
365                                                } else {
366                                                    tracing::info!("Dashboard closing connection gracefully");
367                                                }
368                                            }
369                                            break;
370                                        },
371                                        _ => {
372                                            tracing::debug!(
373                                                "Received message from Dashboard: {} ({})",
374                                                msg.message_type,
375                                                text
376                                            );
377                                        },
378                                    }
379                                } else {
380                                    tracing::debug!("Received non-protocol message: {}", text);
381                                }
382                            },
383                            Message::Close(_) => {
384                                tracing::info!("Dashboard closed connection");
385                                break;
386                            },
387                            _ => {}
388                        }
389                        } else {
390                            // None or error - connection closed
391                            tracing::info!("Dashboard WebSocket stream ended");
392                            break;
393                        }
394                    }
395                    notification_result = rx_guard.recv() => {
396                        if let Some(notification) = notification_result {
397                            // Send notification to Dashboard
398                            if let Err(e) = write.send(Message::Text(notification)).await {
399                                tracing::warn!("Failed to send notification to Dashboard: {}", e);
400                                break;
401                            }
402                            tracing::debug!("Sent db_operation notification to Dashboard");
403                        }
404                    }
405                }
406                drop(rx_guard); // Release the lock after select!
407            } else {
408                // No notification channel - only handle WebSocket messages
409                tokio::select! {
410                    msg_result = read.next() => {
411                        if let Some(Ok(msg)) = msg_result {
412                        match msg {
413                            Message::Text(text) => {
414                                if let Ok(msg) =
415                                    serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
416                                {
417                                    match msg.message_type.as_str() {
418                                        "ping" => {
419                                            // Dashboard sent ping - respond with pong
420                                            tracing::debug!(
421                                                "Received ping from Dashboard, responding with pong"
422                                            );
423                                            let pong_msg = ProtocolMessage::new("pong", EmptyPayload {});
424                                            if let Ok(pong_json) = pong_msg.to_json() {
425                                                if write.send(Message::Text(pong_json)).await.is_err() {
426                                                    tracing::warn!(
427                                                        "Failed to send pong - Dashboard connection lost"
428                                                    );
429                                                    break;
430                                                }
431                                            }
432                                        },
433                                        "error" => {
434                                            // Dashboard sent an error
435                                            if let Ok(error) =
436                                                serde_json::from_value::<ErrorPayload>(msg.payload)
437                                            {
438                                                tracing::error!(
439                                                    "Dashboard error [{}]: {}",
440                                                    error.code,
441                                                    error.message
442                                                );
443                                                if let Some(details) = error.details {
444                                                    tracing::error!("  Details: {}", details);
445                                                }
446
447                                                // Handle critical errors
448                                                match error.code.as_str() {
449                                                    "unsupported_version" => {
450                                                        tracing::error!(
451                                                            "Protocol version mismatch - connection will close"
452                                                        );
453                                                        break;
454                                                    },
455                                                    "invalid_path" => {
456                                                        tracing::error!("Project path rejected by Dashboard");
457                                                        break;
458                                                    },
459                                                    _ => {
460                                                        // Non-critical errors, continue
461                                                    },
462                                                }
463                                            }
464                                        },
465                                        "goodbye" => {
466                                            // Dashboard is closing connection gracefully
467                                            if let Ok(goodbye) =
468                                                serde_json::from_value::<GoodbyePayload>(msg.payload)
469                                            {
470                                                if let Some(reason) = goodbye.reason {
471                                                    tracing::info!("Dashboard closing connection: {}", reason);
472                                                } else {
473                                                    tracing::info!("Dashboard closing connection gracefully");
474                                                }
475                                            }
476                                            break;
477                                        },
478                                        _ => {
479                                            tracing::debug!(
480                                                "Received message from Dashboard: {} ({})",
481                                                msg.message_type,
482                                                text
483                                            );
484                                        },
485                                    }
486                                } else {
487                                    tracing::debug!("Received non-protocol message: {}", text);
488                                }
489                            },
490                            Message::Close(_) => {
491                                tracing::info!("Dashboard closed connection");
492                                break;
493                            }
494                            _ => {}
495                        }
496                        } else {
497                            // None or error - connection closed
498                            tracing::info!("Dashboard WebSocket stream ended");
499                            break;
500                        }
501                    }
502                }
503            }
504        }
505    });
506
507    Ok(())
508}