intent_engine/mcp/ws_client.rs

1// WebSocket client for MCP → Dashboard communication
2// Handles registration and keep-alive for MCP server instances
3
4use anyhow::{Context, Result};
5use futures_util::{SinkExt, StreamExt};
6use serde::{Deserialize, Serialize};
7use std::path::PathBuf;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio_tungstenite::{connect_async, tungstenite::Message};
11
12/// Protocol version
13const PROTOCOL_VERSION: &str = "1.0";
14
15/// Protocol message wrapper
16#[derive(Debug, Serialize, Deserialize)]
17struct ProtocolMessage<T> {
18    version: String,
19    #[serde(rename = "type")]
20    message_type: String,
21    payload: T,
22    timestamp: String,
23}
24
25impl<T: Serialize> ProtocolMessage<T> {
26    fn new(message_type: impl Into<String>, payload: T) -> Self {
27        Self {
28            version: PROTOCOL_VERSION.to_string(),
29            message_type: message_type.into(),
30            payload,
31            timestamp: chrono::Utc::now().to_rfc3339(),
32        }
33    }
34
35    fn to_json(&self) -> Result<String> {
36        serde_json::to_string(self).map_err(Into::into)
37    }
38}
39
40/// Empty payload for ping/pong messages
41#[derive(Debug, Serialize, Deserialize)]
42struct EmptyPayload {}
43
44/// Payload for registered response
45#[derive(Debug, Serialize, Deserialize)]
46struct RegisteredPayload {
47    success: bool,
48}
49
50/// Payload for goodbye message
51#[derive(Debug, Serialize, Deserialize)]
52struct GoodbyePayload {
53    #[serde(skip_serializing_if = "Option::is_none")]
54    reason: Option<String>,
55}
56
57/// Payload for hello message (client → server)
58#[derive(Debug, Serialize, Deserialize)]
59struct HelloPayload {
60    entity_type: String,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    capabilities: Option<Vec<String>>,
63}
64
65/// Payload for welcome message (server → client)
66#[derive(Debug, Serialize, Deserialize)]
67struct WelcomePayload {
68    session_id: String,
69    #[serde(skip_serializing_if = "Option::is_none")]
70    capabilities: Option<Vec<String>>,
71}
72
73/// Payload for error message (server → client)
74#[derive(Debug, Serialize, Deserialize)]
75struct ErrorPayload {
76    code: String,
77    message: String,
78    #[serde(skip_serializing_if = "Option::is_none")]
79    details: Option<serde_json::Value>,
80}
81
82/// Project information sent to Dashboard
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct ProjectInfo {
85    pub path: String,
86    pub name: String,
87    pub db_path: String,
88    #[serde(skip_serializing_if = "Option::is_none")]
89    pub agent: Option<String>,
90    /// Whether this project has an active MCP connection
91    pub mcp_connected: bool,
92    /// Whether the Dashboard serving this project is online
93    pub is_online: bool,
94}
95
96/// Reconnection delays in seconds (exponential backoff with max)
97const RECONNECT_DELAYS: &[u64] = &[1, 2, 4, 8, 16, 32];
98
99/// Start WebSocket client with automatic reconnection
100/// This function runs indefinitely, reconnecting on disconnection
101pub async fn connect_to_dashboard(
102    project_path: PathBuf,
103    db_path: PathBuf,
104    agent: Option<String>,
105    notification_rx: Option<tokio::sync::mpsc::UnboundedReceiver<String>>,
106) -> Result<()> {
107    // Validate project path once at the beginning
108    let normalized_project_path = project_path
109        .canonicalize()
110        .unwrap_or_else(|_| project_path.clone());
111
112    let temp_dir = std::env::temp_dir()
113        .canonicalize()
114        .unwrap_or_else(|_| std::env::temp_dir());
115
116    if normalized_project_path.starts_with(&temp_dir) {
117        tracing::warn!(
118            "Skipping Dashboard registration for temporary path: {}",
119            normalized_project_path.display()
120        );
121        return Ok(()); // Silently skip for temp paths
122    }
123
124    let mut attempt = 0;
125
126    // Convert notification_rx to Option<Arc<Mutex<>>> for sharing across reconnections
127    let notification_rx = notification_rx.map(|rx| Arc::new(tokio::sync::Mutex::new(rx)));
128
129    // Infinite reconnection loop
130    loop {
131        tracing::info!("Connecting to Dashboard (attempt {})...", attempt + 1);
132
133        match connect_and_run(
134            project_path.clone(),
135            db_path.clone(),
136            agent.clone(),
137            notification_rx.clone(),
138        )
139        .await
140        {
141            Ok(()) => {
142                // Graceful close - reset attempt counter and retry immediately
143                tracing::info!("Dashboard connection closed gracefully, reconnecting...");
144                attempt = 0;
145                // Small delay before reconnecting
146                tokio::time::sleep(Duration::from_secs(1)).await;
147            },
148            Err(e) => {
149                // Connection error - use exponential backoff
150                tracing::warn!("Dashboard connection failed: {}. Retrying...", e);
151
152                // Calculate delay with exponential backoff
153                let delay_index = std::cmp::min(attempt, RECONNECT_DELAYS.len() - 1);
154                let base_delay = RECONNECT_DELAYS[delay_index];
155
156                // Add jitter: ±25% random variance
157                let jitter_factor = rand::random::<f64>() * 2.0 - 1.0; // Range: -1.0 to 1.0
158                let jitter_ms = (base_delay * 1000) as f64 * 0.25 * jitter_factor;
159                let delay_ms = (base_delay * 1000) as f64 + jitter_ms;
160                let delay = Duration::from_millis(delay_ms.max(0.0) as u64);
161
162                tracing::info!(
163                    "Waiting {:.1}s before retry (base: {}s + jitter: {:.1}s)",
164                    delay.as_secs_f64(),
165                    base_delay,
166                    jitter_ms / 1000.0
167                );
168
169                tokio::time::sleep(delay).await;
170                attempt += 1;
171            },
172        }
173    }
174}
175
176/// Internal function: Connect to Dashboard and run until disconnection
177/// Returns Ok(()) on graceful close, Err on connection failure
178async fn connect_and_run(
179    project_path: PathBuf,
180    db_path: PathBuf,
181    agent: Option<String>,
182    notification_rx: Option<Arc<tokio::sync::Mutex<tokio::sync::mpsc::UnboundedReceiver<String>>>>,
183) -> Result<()> {
184    // Extract project name from path
185    let project_name = project_path
186        .file_name()
187        .and_then(|n| n.to_str())
188        .unwrap_or("unknown")
189        .to_string();
190
191    // Normalize paths to handle symlinks
192    let normalized_project_path = project_path
193        .canonicalize()
194        .unwrap_or_else(|_| project_path.clone());
195    let normalized_db_path = db_path.canonicalize().unwrap_or_else(|_| db_path.clone());
196
197    // Create project info
198    let project_info = ProjectInfo {
199        path: normalized_project_path.to_string_lossy().to_string(),
200        name: project_name,
201        db_path: normalized_db_path.to_string_lossy().to_string(),
202        agent,
203        mcp_connected: true,
204        is_online: true,
205    };
206
207    // Connect to Dashboard WebSocket
208    let url = "ws://127.0.0.1:11391/ws/mcp";
209    let (ws_stream, _) = connect_async(url)
210        .await
211        .context("Failed to connect to Dashboard WebSocket")?;
212
213    tracing::debug!("Connected to Dashboard at {}", url);
214
215    let (mut write, mut read) = ws_stream.split();
216
217    // Step 1: Send hello message (Protocol v1.0 handshake)
218    let hello_msg = ProtocolMessage::new(
219        "hello",
220        HelloPayload {
221            entity_type: "mcp_server".to_string(),
222            capabilities: Some(vec![]),
223        },
224    );
225    write
226        .send(Message::Text(hello_msg.to_json()?))
227        .await
228        .context("Failed to send hello message")?;
229    tracing::debug!("Sent hello message");
230
231    // Step 2: Wait for welcome response
232    if let Some(Ok(Message::Text(text))) = read.next().await {
233        match serde_json::from_str::<ProtocolMessage<WelcomePayload>>(&text) {
234            Ok(msg) if msg.message_type == "welcome" => {
235                tracing::debug!(
236                    "Received welcome from Dashboard (session: {})",
237                    msg.payload.session_id
238                );
239            },
240            Ok(msg) => {
241                tracing::warn!(
242                    "Expected welcome, received: {} (legacy Dashboard?)",
243                    msg.message_type
244                );
245                // Continue anyway for backward compatibility
246            },
247            Err(e) => {
248                tracing::warn!("Failed to parse welcome message: {}", e);
249            },
250        }
251    }
252
253    // Step 3: Send registration message
254    let register_msg = ProtocolMessage::new("register", project_info.clone());
255    let register_json = register_msg.to_json()?;
256    write
257        .send(Message::Text(register_json))
258        .await
259        .context("Failed to send register message")?;
260
261    // Step 4: Wait for registration confirmation
262    if let Some(Ok(Message::Text(text))) = read.next().await {
263        match serde_json::from_str::<ProtocolMessage<RegisteredPayload>>(&text) {
264            Ok(msg) if msg.message_type == "registered" && msg.payload.success => {
265                tracing::debug!("Successfully registered with Dashboard");
266            },
267            Ok(msg) if msg.message_type == "registered" && !msg.payload.success => {
268                anyhow::bail!("Dashboard rejected registration");
269            },
270            _ => {
271                tracing::debug!("Unexpected response during registration: {}", text);
272            },
273        }
274    }
275
276    // Spawn read/write task to handle messages and respond to pings
277    // Protocol v1.0 Section 4.1.3: Dashboard sends ping, client responds with pong
278    tokio::spawn(async move {
279        loop {
280            // Handle notification channel if available
281            if let Some(ref rx) = notification_rx {
282                let mut rx_guard = rx.lock().await;
283                tokio::select! {
284                    msg_result = read.next() => {
285                        if let Some(Ok(msg)) = msg_result {
286                        match msg {
287                            Message::Text(text) => {
288                                if let Ok(msg) =
289                                    serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
290                                {
291                                    match msg.message_type.as_str() {
292                                        "ping" => {
293                                            // Dashboard sent ping - respond with pong
294                                            tracing::debug!(
295                                                "Received ping from Dashboard, responding with pong"
296                                            );
297                                            let pong_msg = ProtocolMessage::new("pong", EmptyPayload {});
298                                            if let Ok(pong_json) = pong_msg.to_json() {
299                                                if write.send(Message::Text(pong_json)).await.is_err() {
300                                                    tracing::warn!(
301                                                        "Failed to send pong - Dashboard connection lost"
302                                                    );
303                                                    break;
304                                                }
305                                            }
306                                        },
307                                        "error" => {
308                                            // Dashboard sent an error
309                                            if let Ok(error) =
310                                                serde_json::from_value::<ErrorPayload>(msg.payload)
311                                            {
312                                                tracing::error!(
313                                                    "Dashboard error [{}]: {}",
314                                                    error.code,
315                                                    error.message
316                                                );
317                                                if let Some(details) = error.details {
318                                                    tracing::error!("  Details: {}", details);
319                                                }
320
321                                                // Handle critical errors
322                                                match error.code.as_str() {
323                                                    "unsupported_version" => {
324                                                        tracing::error!(
325                                                            "Protocol version mismatch - connection will close"
326                                                        );
327                                                        break;
328                                                    },
329                                                    "invalid_path" => {
330                                                        tracing::error!("Project path rejected by Dashboard");
331                                                        break;
332                                                    },
333                                                    _ => {
334                                                        // Non-critical errors, continue
335                                                    },
336                                                }
337                                            }
338                                        },
339                                        "goodbye" => {
340                                            // Dashboard is closing connection gracefully
341                                            if let Ok(goodbye) =
342                                                serde_json::from_value::<GoodbyePayload>(msg.payload)
343                                            {
344                                                if let Some(reason) = goodbye.reason {
345                                                    tracing::info!("Dashboard closing connection: {}", reason);
346                                                } else {
347                                                    tracing::info!("Dashboard closing connection gracefully");
348                                                }
349                                            }
350                                            break;
351                                        },
352                                        _ => {
353                                            tracing::debug!(
354                                                "Received message from Dashboard: {} ({})",
355                                                msg.message_type,
356                                                text
357                                            );
358                                        },
359                                    }
360                                } else {
361                                    tracing::debug!("Received non-protocol message: {}", text);
362                                }
363                            },
364                            Message::Close(_) => {
365                                tracing::info!("Dashboard closed connection");
366                                break;
367                            },
368                            _ => {}
369                        }
370                        } else {
371                            // None or error - connection closed
372                            tracing::info!("Dashboard WebSocket stream ended");
373                            break;
374                        }
375                    }
376                    notification_result = rx_guard.recv() => {
377                        if let Some(notification) = notification_result {
378                            // Send notification to Dashboard
379                            if let Err(e) = write.send(Message::Text(notification)).await {
380                                tracing::warn!("Failed to send notification to Dashboard: {}", e);
381                                break;
382                            }
383                            tracing::debug!("Sent db_operation notification to Dashboard");
384                        }
385                    }
386                }
387                drop(rx_guard); // Release the lock after select!
388            } else {
389                // No notification channel - only handle WebSocket messages
390                tokio::select! {
391                    msg_result = read.next() => {
392                        if let Some(Ok(msg)) = msg_result {
393                        match msg {
394                            Message::Text(text) => {
395                                if let Ok(msg) =
396                                    serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
397                                {
398                                    match msg.message_type.as_str() {
399                                        "ping" => {
400                                            // Dashboard sent ping - respond with pong
401                                            tracing::debug!(
402                                                "Received ping from Dashboard, responding with pong"
403                                            );
404                                            let pong_msg = ProtocolMessage::new("pong", EmptyPayload {});
405                                            if let Ok(pong_json) = pong_msg.to_json() {
406                                                if write.send(Message::Text(pong_json)).await.is_err() {
407                                                    tracing::warn!(
408                                                        "Failed to send pong - Dashboard connection lost"
409                                                    );
410                                                    break;
411                                                }
412                                            }
413                                        },
414                                        "error" => {
415                                            // Dashboard sent an error
416                                            if let Ok(error) =
417                                                serde_json::from_value::<ErrorPayload>(msg.payload)
418                                            {
419                                                tracing::error!(
420                                                    "Dashboard error [{}]: {}",
421                                                    error.code,
422                                                    error.message
423                                                );
424                                                if let Some(details) = error.details {
425                                                    tracing::error!("  Details: {}", details);
426                                                }
427
428                                                // Handle critical errors
429                                                match error.code.as_str() {
430                                                    "unsupported_version" => {
431                                                        tracing::error!(
432                                                            "Protocol version mismatch - connection will close"
433                                                        );
434                                                        break;
435                                                    },
436                                                    "invalid_path" => {
437                                                        tracing::error!("Project path rejected by Dashboard");
438                                                        break;
439                                                    },
440                                                    _ => {
441                                                        // Non-critical errors, continue
442                                                    },
443                                                }
444                                            }
445                                        },
446                                        "goodbye" => {
447                                            // Dashboard is closing connection gracefully
448                                            if let Ok(goodbye) =
449                                                serde_json::from_value::<GoodbyePayload>(msg.payload)
450                                            {
451                                                if let Some(reason) = goodbye.reason {
452                                                    tracing::info!("Dashboard closing connection: {}", reason);
453                                                } else {
454                                                    tracing::info!("Dashboard closing connection gracefully");
455                                                }
456                                            }
457                                            break;
458                                        },
459                                        _ => {
460                                            tracing::debug!(
461                                                "Received message from Dashboard: {} ({})",
462                                                msg.message_type,
463                                                text
464                                            );
465                                        },
466                                    }
467                                } else {
468                                    tracing::debug!("Received non-protocol message: {}", text);
469                                }
470                            },
471                            Message::Close(_) => {
472                                tracing::info!("Dashboard closed connection");
473                                break;
474                            }
475                            _ => {}
476                        }
477                        } else {
478                            // None or error - connection closed
479                            tracing::info!("Dashboard WebSocket stream ended");
480                            break;
481                        }
482                    }
483                }
484            }
485        }
486    });
487
488    Ok(())
489}