// intent_engine/mcp/ws_client.rs
//
// WebSocket client for MCP → Dashboard communication.
// Handles registration and keep-alive for MCP server instances.
use std::path::PathBuf;
use std::time::Duration;

use anyhow::{Context, Result};
use futures_util::{stream::Stream, SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use tokio_tungstenite::{connect_async, tungstenite::Message};
10
11/// Protocol version
12const PROTOCOL_VERSION: &str = "1.0";
13
14/// Protocol message wrapper
15#[derive(Debug, Serialize, Deserialize)]
16struct ProtocolMessage<T> {
17    version: String,
18    #[serde(rename = "type")]
19    message_type: String,
20    payload: T,
21    timestamp: String,
22}
23
24impl<T: Serialize> ProtocolMessage<T> {
25    fn new(message_type: impl Into<String>, payload: T) -> Self {
26        Self {
27            version: PROTOCOL_VERSION.to_string(),
28            message_type: message_type.into(),
29            payload,
30            timestamp: chrono::Utc::now().to_rfc3339(),
31        }
32    }
33
34    fn to_json(&self) -> Result<String> {
35        serde_json::to_string(self).map_err(Into::into)
36    }
37}
38
39/// Empty payload for ping/pong messages
40#[derive(Debug, Serialize, Deserialize)]
41struct EmptyPayload {}
42
43/// Payload for registered response
44#[derive(Debug, Serialize, Deserialize)]
45struct RegisteredPayload {
46    success: bool,
47}
48
49/// Payload for goodbye message
50#[derive(Debug, Serialize, Deserialize)]
51struct GoodbyePayload {
52    #[serde(skip_serializing_if = "Option::is_none")]
53    reason: Option<String>,
54}
55
56/// Payload for hello message (client → server)
57#[derive(Debug, Serialize, Deserialize)]
58struct HelloPayload {
59    entity_type: String,
60    #[serde(skip_serializing_if = "Option::is_none")]
61    capabilities: Option<Vec<String>>,
62}
63
64/// Payload for welcome message (server → client)
65#[derive(Debug, Serialize, Deserialize)]
66struct WelcomePayload {
67    session_id: String,
68    #[serde(skip_serializing_if = "Option::is_none")]
69    capabilities: Option<Vec<String>>,
70}
71
72/// Payload for error message (server → client)
73#[derive(Debug, Serialize, Deserialize)]
74struct ErrorPayload {
75    code: String,
76    message: String,
77    #[serde(skip_serializing_if = "Option::is_none")]
78    details: Option<serde_json::Value>,
79}
80
81/// Project information sent to Dashboard
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct ProjectInfo {
84    pub path: String,
85    pub name: String,
86    pub db_path: String,
87    #[serde(skip_serializing_if = "Option::is_none")]
88    pub agent: Option<String>,
89}
90
91/// Reconnection delays in seconds (exponential backoff with max)
92const RECONNECT_DELAYS: &[u64] = &[1, 2, 4, 8, 16, 32];
93
94/// Start WebSocket client with automatic reconnection
95/// This function runs indefinitely, reconnecting on disconnection
96pub async fn connect_to_dashboard(
97    project_path: PathBuf,
98    db_path: PathBuf,
99    agent: Option<String>,
100) -> Result<()> {
101    // Validate project path once at the beginning
102    let normalized_project_path = project_path
103        .canonicalize()
104        .unwrap_or_else(|_| project_path.clone());
105
106    let temp_dir = std::env::temp_dir()
107        .canonicalize()
108        .unwrap_or_else(|_| std::env::temp_dir());
109
110    if normalized_project_path.starts_with(&temp_dir) {
111        tracing::warn!(
112            "Skipping Dashboard registration for temporary path: {}",
113            normalized_project_path.display()
114        );
115        return Ok(()); // Silently skip for temp paths
116    }
117
118    let mut attempt = 0;
119
120    // Infinite reconnection loop
121    loop {
122        tracing::info!("Connecting to Dashboard (attempt {})...", attempt + 1);
123
124        match connect_and_run(project_path.clone(), db_path.clone(), agent.clone()).await {
125            Ok(()) => {
126                // Graceful close - reset attempt counter and retry immediately
127                tracing::info!("Dashboard connection closed gracefully, reconnecting...");
128                attempt = 0;
129                // Small delay before reconnecting
130                tokio::time::sleep(Duration::from_secs(1)).await;
131            },
132            Err(e) => {
133                // Connection error - use exponential backoff
134                tracing::warn!("Dashboard connection failed: {}. Retrying...", e);
135
136                // Calculate delay with exponential backoff
137                let delay_index = std::cmp::min(attempt, RECONNECT_DELAYS.len() - 1);
138                let base_delay = RECONNECT_DELAYS[delay_index];
139
140                // Add jitter: ±25% random variance
141                let jitter_factor = rand::random::<f64>() * 2.0 - 1.0; // Range: -1.0 to 1.0
142                let jitter_ms = (base_delay * 1000) as f64 * 0.25 * jitter_factor;
143                let delay_ms = (base_delay * 1000) as f64 + jitter_ms;
144                let delay = Duration::from_millis(delay_ms.max(0.0) as u64);
145
146                tracing::info!(
147                    "Waiting {:.1}s before retry (base: {}s + jitter: {:.1}s)",
148                    delay.as_secs_f64(),
149                    base_delay,
150                    jitter_ms / 1000.0
151                );
152
153                tokio::time::sleep(delay).await;
154                attempt += 1;
155            },
156        }
157    }
158}
159
160/// Internal function: Connect to Dashboard and run until disconnection
161/// Returns Ok(()) on graceful close, Err on connection failure
162async fn connect_and_run(
163    project_path: PathBuf,
164    db_path: PathBuf,
165    agent: Option<String>,
166) -> Result<()> {
167    // Extract project name from path
168    let project_name = project_path
169        .file_name()
170        .and_then(|n| n.to_str())
171        .unwrap_or("unknown")
172        .to_string();
173
174    // Normalize paths to handle symlinks
175    let normalized_project_path = project_path
176        .canonicalize()
177        .unwrap_or_else(|_| project_path.clone());
178    let normalized_db_path = db_path.canonicalize().unwrap_or_else(|_| db_path.clone());
179
180    // Create project info
181    let project_info = ProjectInfo {
182        path: normalized_project_path.to_string_lossy().to_string(),
183        name: project_name,
184        db_path: normalized_db_path.to_string_lossy().to_string(),
185        agent,
186    };
187
188    // Connect to Dashboard WebSocket
189    let url = "ws://127.0.0.1:11391/ws/mcp";
190    let (ws_stream, _) = connect_async(url)
191        .await
192        .context("Failed to connect to Dashboard WebSocket")?;
193
194    tracing::debug!("Connected to Dashboard at {}", url);
195
196    let (mut write, mut read) = ws_stream.split();
197
198    // Step 1: Send hello message (Protocol v1.0 handshake)
199    let hello_msg = ProtocolMessage::new(
200        "hello",
201        HelloPayload {
202            entity_type: "mcp_server".to_string(),
203            capabilities: Some(vec![]),
204        },
205    );
206    write
207        .send(Message::Text(hello_msg.to_json()?))
208        .await
209        .context("Failed to send hello message")?;
210    tracing::debug!("Sent hello message");
211
212    // Step 2: Wait for welcome response
213    if let Some(Ok(Message::Text(text))) = read.next().await {
214        match serde_json::from_str::<ProtocolMessage<WelcomePayload>>(&text) {
215            Ok(msg) if msg.message_type == "welcome" => {
216                tracing::debug!(
217                    "Received welcome from Dashboard (session: {})",
218                    msg.payload.session_id
219                );
220            },
221            Ok(msg) => {
222                tracing::warn!(
223                    "Expected welcome, received: {} (legacy Dashboard?)",
224                    msg.message_type
225                );
226                // Continue anyway for backward compatibility
227            },
228            Err(e) => {
229                tracing::warn!("Failed to parse welcome message: {}", e);
230            },
231        }
232    }
233
234    // Step 3: Send registration message
235    let register_msg = ProtocolMessage::new("register", project_info.clone());
236    let register_json = register_msg.to_json()?;
237    write
238        .send(Message::Text(register_json))
239        .await
240        .context("Failed to send register message")?;
241
242    // Step 4: Wait for registration confirmation
243    if let Some(Ok(Message::Text(text))) = read.next().await {
244        match serde_json::from_str::<ProtocolMessage<RegisteredPayload>>(&text) {
245            Ok(msg) if msg.message_type == "registered" && msg.payload.success => {
246                tracing::debug!("Successfully registered with Dashboard");
247            },
248            Ok(msg) if msg.message_type == "registered" && !msg.payload.success => {
249                anyhow::bail!("Dashboard rejected registration");
250            },
251            _ => {
252                tracing::debug!("Unexpected response during registration: {}", text);
253            },
254        }
255    }
256
257    // Spawn read/write task to handle messages and respond to pings
258    // Protocol v1.0 Section 4.1.3: Dashboard sends ping, client responds with pong
259    tokio::spawn(async move {
260        while let Some(Ok(msg)) = read.next().await {
261            match msg {
262                Message::Text(text) => {
263                    if let Ok(msg) =
264                        serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
265                    {
266                        match msg.message_type.as_str() {
267                            "ping" => {
268                                // Dashboard sent ping - respond with pong
269                                tracing::debug!(
270                                    "Received ping from Dashboard, responding with pong"
271                                );
272                                let pong_msg = ProtocolMessage::new("pong", EmptyPayload {});
273                                if let Ok(pong_json) = pong_msg.to_json() {
274                                    if write.send(Message::Text(pong_json)).await.is_err() {
275                                        tracing::warn!(
276                                            "Failed to send pong - Dashboard connection lost"
277                                        );
278                                        break;
279                                    }
280                                }
281                            },
282                            "error" => {
283                                // Dashboard sent an error
284                                if let Ok(error) =
285                                    serde_json::from_value::<ErrorPayload>(msg.payload)
286                                {
287                                    tracing::error!(
288                                        "Dashboard error [{}]: {}",
289                                        error.code,
290                                        error.message
291                                    );
292                                    if let Some(details) = error.details {
293                                        tracing::error!("  Details: {}", details);
294                                    }
295
296                                    // Handle critical errors
297                                    match error.code.as_str() {
298                                        "unsupported_version" => {
299                                            tracing::error!(
300                                                "Protocol version mismatch - connection will close"
301                                            );
302                                            break;
303                                        },
304                                        "invalid_path" => {
305                                            tracing::error!("Project path rejected by Dashboard");
306                                            break;
307                                        },
308                                        _ => {
309                                            // Non-critical errors, continue
310                                        },
311                                    }
312                                }
313                            },
314                            "goodbye" => {
315                                // Dashboard is closing connection gracefully
316                                if let Ok(goodbye) =
317                                    serde_json::from_value::<GoodbyePayload>(msg.payload)
318                                {
319                                    if let Some(reason) = goodbye.reason {
320                                        tracing::info!("Dashboard closing connection: {}", reason);
321                                    } else {
322                                        tracing::info!("Dashboard closing connection gracefully");
323                                    }
324                                }
325                                break;
326                            },
327                            _ => {
328                                tracing::debug!(
329                                    "Received message from Dashboard: {} ({})",
330                                    msg.message_type,
331                                    text
332                                );
333                            },
334                        }
335                    } else {
336                        tracing::debug!("Received non-protocol message: {}", text);
337                    }
338                },
339                Message::Close(_) => {
340                    tracing::info!("Dashboard closed connection");
341                    break;
342                },
343                _ => {},
344            }
345        }
346    });
347
348    Ok(())
349}