intent_engine/mcp/
ws_client.rs

1// WebSocket client for MCP → Dashboard communication
2// Handles registration and keep-alive for MCP server instances
3
4use anyhow::{Context, Result};
5use futures_util::{SinkExt, StreamExt};
6use serde::{Deserialize, Serialize};
7use std::path::PathBuf;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio_tungstenite::{connect_async, tungstenite::Message};
11
12/// Protocol version
13const PROTOCOL_VERSION: &str = "1.0";
14
15/// Protocol message wrapper
16#[derive(Debug, Serialize, Deserialize)]
17struct ProtocolMessage<T> {
18    version: String,
19    #[serde(rename = "type")]
20    message_type: String,
21    payload: T,
22    timestamp: String,
23}
24
25impl<T: Serialize> ProtocolMessage<T> {
26    fn new(message_type: impl Into<String>, payload: T) -> Self {
27        Self {
28            version: PROTOCOL_VERSION.to_string(),
29            message_type: message_type.into(),
30            payload,
31            timestamp: chrono::Utc::now().to_rfc3339(),
32        }
33    }
34
35    fn to_json(&self) -> Result<String> {
36        serde_json::to_string(self).map_err(Into::into)
37    }
38}
39
/// Empty payload for ping/pong messages (this module uses it for the `pong` reply).
///
/// Declared as a braced struct with no fields, so serde encodes it as an empty
/// JSON object `{}`. A unit struct (`struct EmptyPayload;`) would encode as `null`.
#[derive(Debug, Serialize, Deserialize)]
struct EmptyPayload {}
43
/// Payload for registered response: the Dashboard's reply (server → client)
/// to our `register` message.
#[derive(Debug, Serialize, Deserialize)]
struct RegisteredPayload {
    /// `false` means the Dashboard refused the registration; the handshake then fails.
    success: bool,
}
49
/// Payload for goodbye message, sent by the Dashboard when it is closing the
/// connection gracefully.
#[derive(Debug, Serialize, Deserialize)]
struct GoodbyePayload {
    /// Optional human-readable explanation; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    reason: Option<String>,
}
56
/// Payload for hello message (client → server), the first frame of the
/// Protocol v1.0 handshake.
#[derive(Debug, Serialize, Deserialize)]
struct HelloPayload {
    /// Kind of client announcing itself (this module sends `"mcp_server"`).
    entity_type: String,
    /// Optional capability list; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    capabilities: Option<Vec<String>>,
}
64
/// Payload for welcome message (server → client), the Dashboard's answer to `hello`.
#[derive(Debug, Serialize, Deserialize)]
struct WelcomePayload {
    /// Session identifier assigned by the Dashboard; this module only logs it.
    session_id: String,
    /// Capabilities advertised by the Dashboard; not read by this module.
    #[serde(skip_serializing_if = "Option::is_none")]
    capabilities: Option<Vec<String>>,
}
72
/// Payload for error message (server → client).
#[derive(Debug, Serialize, Deserialize)]
struct ErrorPayload {
    /// Machine-readable error code. In this module `unsupported_version` and
    /// `invalid_path` end the session; any other code is only logged.
    code: String,
    /// Human-readable description of the error.
    message: String,
    /// Optional extra data of arbitrary JSON shape, logged when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    details: Option<serde_json::Value>,
}
81
/// Project information sent to Dashboard as the payload of the `register` message.
///
/// Within this module it is built from the project and database paths after
/// they have been canonicalized (falling back to the paths as given).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectInfo {
    /// Project root path as a string (built here with a lossy UTF-8 conversion).
    pub path: String,
    /// Display name of the project (built here from the last path component,
    /// or `"unknown"` when that is unavailable).
    pub name: String,
    /// Database path associated with the project, as a string.
    pub db_path: String,
    /// Optional agent identifier; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent: Option<String>,
    /// Whether this project has an active MCP connection
    pub mcp_connected: bool,
    /// Whether the Dashboard serving this project is online
    pub is_online: bool,
}
95
/// Back-off schedule, in seconds, for reconnecting after a failed Dashboard session.
///
/// Each entry doubles the previous one. Once the failure count runs past the end
/// of the table, the last entry (32 s) is reused as the cap.
const RECONNECT_DELAYS: &[u64] = &[1, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5];
98
99/// Start WebSocket client with automatic reconnection
100/// This function runs indefinitely, reconnecting on disconnection
101pub async fn connect_to_dashboard(
102    project_path: PathBuf,
103    db_path: PathBuf,
104    agent: Option<String>,
105    notification_rx: Option<tokio::sync::mpsc::UnboundedReceiver<String>>,
106    dashboard_port: Option<u16>,
107) -> Result<()> {
108    // Validate project path once at the beginning
109    let normalized_project_path = project_path
110        .canonicalize()
111        .unwrap_or_else(|_| project_path.clone());
112
113    let temp_dir = std::env::temp_dir()
114        .canonicalize()
115        .unwrap_or_else(|_| std::env::temp_dir());
116
117    if normalized_project_path.starts_with(&temp_dir) {
118        tracing::warn!(
119            "Skipping Dashboard registration for temporary path: {}",
120            normalized_project_path.display()
121        );
122        return Ok(()); // Silently skip for temp paths
123    }
124
125    let mut attempt = 0;
126
127    // Convert notification_rx to Option<Arc<Mutex<>>> for sharing across reconnections
128    let notification_rx = notification_rx.map(|rx| Arc::new(tokio::sync::Mutex::new(rx)));
129
130    // Infinite reconnection loop
131    loop {
132        tracing::info!("Connecting to Dashboard (attempt {})...", attempt + 1);
133
134        match connect_and_run(
135            project_path.clone(),
136            db_path.clone(),
137            agent.clone(),
138            notification_rx.clone(),
139            dashboard_port,
140        )
141        .await
142        {
143            Ok(()) => {
144                // Graceful close - reset attempt counter and retry immediately
145                tracing::info!("Dashboard connection closed gracefully, reconnecting...");
146                attempt = 0;
147                // Small delay before reconnecting
148                tokio::time::sleep(Duration::from_secs(1)).await;
149            },
150            Err(e) => {
151                // Connection error - use exponential backoff
152                tracing::warn!("Dashboard connection failed: {}. Retrying...", e);
153
154                // Calculate delay with exponential backoff
155                let delay_index = std::cmp::min(attempt, RECONNECT_DELAYS.len() - 1);
156                let base_delay = RECONNECT_DELAYS[delay_index];
157
158                // Add jitter: ±25% random variance
159                let jitter_factor = rand::random::<f64>() * 2.0 - 1.0; // Range: -1.0 to 1.0
160                let jitter_ms = (base_delay * 1000) as f64 * 0.25 * jitter_factor;
161                let delay_ms = (base_delay * 1000) as f64 + jitter_ms;
162                let delay = Duration::from_millis(delay_ms.max(0.0) as u64);
163
164                tracing::info!(
165                    "Waiting {:.1}s before retry (base: {}s + jitter: {:.1}s)",
166                    delay.as_secs_f64(),
167                    base_delay,
168                    jitter_ms / 1000.0
169                );
170
171                tokio::time::sleep(delay).await;
172                attempt += 1;
173            },
174        }
175    }
176}
177
178/// Internal function: Connect to Dashboard and run until disconnection
179/// Returns Ok(()) on graceful close, Err on connection failure
180async fn connect_and_run(
181    project_path: PathBuf,
182    db_path: PathBuf,
183    agent: Option<String>,
184    notification_rx: Option<Arc<tokio::sync::Mutex<tokio::sync::mpsc::UnboundedReceiver<String>>>>,
185    dashboard_port: Option<u16>,
186) -> Result<()> {
187    // Extract project name from path
188    let project_name = project_path
189        .file_name()
190        .and_then(|n| n.to_str())
191        .unwrap_or("unknown")
192        .to_string();
193
194    // Normalize paths to handle symlinks
195    let normalized_project_path = project_path
196        .canonicalize()
197        .unwrap_or_else(|_| project_path.clone());
198    let normalized_db_path = db_path.canonicalize().unwrap_or_else(|_| db_path.clone());
199
200    // Create project info
201    let project_info = ProjectInfo {
202        path: normalized_project_path.to_string_lossy().to_string(),
203        name: project_name,
204        db_path: normalized_db_path.to_string_lossy().to_string(),
205        agent,
206        mcp_connected: true,
207        is_online: true,
208    };
209
210    // Connect to Dashboard WebSocket
211    let port = dashboard_port.unwrap_or(11391);
212    let url = format!("ws://127.0.0.1:{}/ws/mcp", port);
213    let (ws_stream, _) = connect_async(&url)
214        .await
215        .context("Failed to connect to Dashboard WebSocket")?;
216
217    tracing::debug!("Connected to Dashboard at {}", url);
218
219    let (mut write, mut read) = ws_stream.split();
220
221    // Step 1: Send hello message (Protocol v1.0 handshake)
222    let hello_msg = ProtocolMessage::new(
223        "hello",
224        HelloPayload {
225            entity_type: "mcp_server".to_string(),
226            capabilities: Some(vec![]),
227        },
228    );
229    write
230        .send(Message::Text(hello_msg.to_json()?))
231        .await
232        .context("Failed to send hello message")?;
233    tracing::debug!("Sent hello message");
234
235    // Step 2: Wait for welcome response
236    if let Some(Ok(Message::Text(text))) = read.next().await {
237        match serde_json::from_str::<ProtocolMessage<WelcomePayload>>(&text) {
238            Ok(msg) if msg.message_type == "welcome" => {
239                tracing::debug!(
240                    "Received welcome from Dashboard (session: {})",
241                    msg.payload.session_id
242                );
243            },
244            Ok(msg) => {
245                tracing::warn!(
246                    "Expected welcome, received: {} (legacy Dashboard?)",
247                    msg.message_type
248                );
249                // Continue anyway for backward compatibility
250            },
251            Err(e) => {
252                tracing::warn!("Failed to parse welcome message: {}", e);
253            },
254        }
255    }
256
257    // Step 3: Send registration message
258    let register_msg = ProtocolMessage::new("register", project_info.clone());
259    let register_json = register_msg.to_json()?;
260    write
261        .send(Message::Text(register_json))
262        .await
263        .context("Failed to send register message")?;
264
265    // Step 4: Wait for registration confirmation
266    if let Some(Ok(Message::Text(text))) = read.next().await {
267        match serde_json::from_str::<ProtocolMessage<RegisteredPayload>>(&text) {
268            Ok(msg) if msg.message_type == "registered" && msg.payload.success => {
269                tracing::debug!("Successfully registered with Dashboard");
270            },
271            Ok(msg) if msg.message_type == "registered" && !msg.payload.success => {
272                anyhow::bail!("Dashboard rejected registration");
273            },
274            _ => {
275                tracing::debug!("Unexpected response during registration: {}", text);
276            },
277        }
278    }
279
280    // Spawn read/write task to handle messages and respond to pings
281    // Protocol v1.0 Section 4.1.3: Dashboard sends ping, client responds with pong
282    tokio::spawn(async move {
283        loop {
284            // Handle notification channel if available
285            if let Some(ref rx) = notification_rx {
286                let mut rx_guard = rx.lock().await;
287                tokio::select! {
288                    msg_result = read.next() => {
289                        if let Some(Ok(msg)) = msg_result {
290                        match msg {
291                            Message::Text(text) => {
292                                if let Ok(msg) =
293                                    serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
294                                {
295                                    match msg.message_type.as_str() {
296                                        "ping" => {
297                                            // Dashboard sent ping - respond with pong
298                                            tracing::debug!(
299                                                "Received ping from Dashboard, responding with pong"
300                                            );
301                                            let pong_msg = ProtocolMessage::new("pong", EmptyPayload {});
302                                            if let Ok(pong_json) = pong_msg.to_json() {
303                                                if write.send(Message::Text(pong_json)).await.is_err() {
304                                                    tracing::warn!(
305                                                        "Failed to send pong - Dashboard connection lost"
306                                                    );
307                                                    break;
308                                                }
309                                            }
310                                        },
311                                        "error" => {
312                                            // Dashboard sent an error
313                                            if let Ok(error) =
314                                                serde_json::from_value::<ErrorPayload>(msg.payload)
315                                            {
316                                                tracing::error!(
317                                                    "Dashboard error [{}]: {}",
318                                                    error.code,
319                                                    error.message
320                                                );
321                                                if let Some(details) = error.details {
322                                                    tracing::error!("  Details: {}", details);
323                                                }
324
325                                                // Handle critical errors
326                                                match error.code.as_str() {
327                                                    "unsupported_version" => {
328                                                        tracing::error!(
329                                                            "Protocol version mismatch - connection will close"
330                                                        );
331                                                        break;
332                                                    },
333                                                    "invalid_path" => {
334                                                        tracing::error!("Project path rejected by Dashboard");
335                                                        break;
336                                                    },
337                                                    _ => {
338                                                        // Non-critical errors, continue
339                                                    },
340                                                }
341                                            }
342                                        },
343                                        "goodbye" => {
344                                            // Dashboard is closing connection gracefully
345                                            if let Ok(goodbye) =
346                                                serde_json::from_value::<GoodbyePayload>(msg.payload)
347                                            {
348                                                if let Some(reason) = goodbye.reason {
349                                                    tracing::info!("Dashboard closing connection: {}", reason);
350                                                } else {
351                                                    tracing::info!("Dashboard closing connection gracefully");
352                                                }
353                                            }
354                                            break;
355                                        },
356                                        _ => {
357                                            tracing::debug!(
358                                                "Received message from Dashboard: {} ({})",
359                                                msg.message_type,
360                                                text
361                                            );
362                                        },
363                                    }
364                                } else {
365                                    tracing::debug!("Received non-protocol message: {}", text);
366                                }
367                            },
368                            Message::Close(_) => {
369                                tracing::info!("Dashboard closed connection");
370                                break;
371                            },
372                            _ => {}
373                        }
374                        } else {
375                            // None or error - connection closed
376                            tracing::info!("Dashboard WebSocket stream ended");
377                            break;
378                        }
379                    }
380                    notification_result = rx_guard.recv() => {
381                        if let Some(notification) = notification_result {
382                            // Send notification to Dashboard
383                            if let Err(e) = write.send(Message::Text(notification)).await {
384                                tracing::warn!("Failed to send notification to Dashboard: {}", e);
385                                break;
386                            }
387                            tracing::debug!("Sent db_operation notification to Dashboard");
388                        }
389                    }
390                }
391                drop(rx_guard); // Release the lock after select!
392            } else {
393                // No notification channel - only handle WebSocket messages
394                tokio::select! {
395                    msg_result = read.next() => {
396                        if let Some(Ok(msg)) = msg_result {
397                        match msg {
398                            Message::Text(text) => {
399                                if let Ok(msg) =
400                                    serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
401                                {
402                                    match msg.message_type.as_str() {
403                                        "ping" => {
404                                            // Dashboard sent ping - respond with pong
405                                            tracing::debug!(
406                                                "Received ping from Dashboard, responding with pong"
407                                            );
408                                            let pong_msg = ProtocolMessage::new("pong", EmptyPayload {});
409                                            if let Ok(pong_json) = pong_msg.to_json() {
410                                                if write.send(Message::Text(pong_json)).await.is_err() {
411                                                    tracing::warn!(
412                                                        "Failed to send pong - Dashboard connection lost"
413                                                    );
414                                                    break;
415                                                }
416                                            }
417                                        },
418                                        "error" => {
419                                            // Dashboard sent an error
420                                            if let Ok(error) =
421                                                serde_json::from_value::<ErrorPayload>(msg.payload)
422                                            {
423                                                tracing::error!(
424                                                    "Dashboard error [{}]: {}",
425                                                    error.code,
426                                                    error.message
427                                                );
428                                                if let Some(details) = error.details {
429                                                    tracing::error!("  Details: {}", details);
430                                                }
431
432                                                // Handle critical errors
433                                                match error.code.as_str() {
434                                                    "unsupported_version" => {
435                                                        tracing::error!(
436                                                            "Protocol version mismatch - connection will close"
437                                                        );
438                                                        break;
439                                                    },
440                                                    "invalid_path" => {
441                                                        tracing::error!("Project path rejected by Dashboard");
442                                                        break;
443                                                    },
444                                                    _ => {
445                                                        // Non-critical errors, continue
446                                                    },
447                                                }
448                                            }
449                                        },
450                                        "goodbye" => {
451                                            // Dashboard is closing connection gracefully
452                                            if let Ok(goodbye) =
453                                                serde_json::from_value::<GoodbyePayload>(msg.payload)
454                                            {
455                                                if let Some(reason) = goodbye.reason {
456                                                    tracing::info!("Dashboard closing connection: {}", reason);
457                                                } else {
458                                                    tracing::info!("Dashboard closing connection gracefully");
459                                                }
460                                            }
461                                            break;
462                                        },
463                                        _ => {
464                                            tracing::debug!(
465                                                "Received message from Dashboard: {} ({})",
466                                                msg.message_type,
467                                                text
468                                            );
469                                        },
470                                    }
471                                } else {
472                                    tracing::debug!("Received non-protocol message: {}", text);
473                                }
474                            },
475                            Message::Close(_) => {
476                                tracing::info!("Dashboard closed connection");
477                                break;
478                            }
479                            _ => {}
480                        }
481                        } else {
482                            // None or error - connection closed
483                            tracing::info!("Dashboard WebSocket stream ended");
484                            break;
485                        }
486                    }
487                }
488            }
489        }
490    });
491
492    Ok(())
493}