intent_engine/mcp/ws_client.rs
1// WebSocket client for MCP → Dashboard communication
2// Handles registration and keep-alive for MCP server instances
3
4use anyhow::{Context, Result};
5use futures_util::{SinkExt, StreamExt};
6use serde::{Deserialize, Serialize};
7use std::path::PathBuf;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio_tungstenite::{connect_async, tungstenite::Message};
11
/// Version string stamped into the `version` field of every outgoing protocol message.
const PROTOCOL_VERSION: &str = "1.0";
14
15/// Protocol message wrapper
16#[derive(Debug, Serialize, Deserialize)]
17struct ProtocolMessage<T> {
18 version: String,
19 #[serde(rename = "type")]
20 message_type: String,
21 payload: T,
22 timestamp: String,
23}
24
25impl<T: Serialize> ProtocolMessage<T> {
26 fn new(message_type: impl Into<String>, payload: T) -> Self {
27 Self {
28 version: PROTOCOL_VERSION.to_string(),
29 message_type: message_type.into(),
30 payload,
31 timestamp: chrono::Utc::now().to_rfc3339(),
32 }
33 }
34
35 fn to_json(&self) -> Result<String> {
36 serde_json::to_string(self).map_err(Into::into)
37 }
38}
39
40/// Empty payload for ping/pong messages
41#[derive(Debug, Serialize, Deserialize)]
42struct EmptyPayload {}
43
44/// Payload for registered response
45#[derive(Debug, Serialize, Deserialize)]
46struct RegisteredPayload {
47 success: bool,
48}
49
50/// Payload for goodbye message
51#[derive(Debug, Serialize, Deserialize)]
52struct GoodbyePayload {
53 #[serde(skip_serializing_if = "Option::is_none")]
54 reason: Option<String>,
55}
56
57/// Payload for hello message (client → server)
58#[derive(Debug, Serialize, Deserialize)]
59struct HelloPayload {
60 entity_type: String,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 capabilities: Option<Vec<String>>,
63}
64
65/// Payload for welcome message (server → client)
66#[derive(Debug, Serialize, Deserialize)]
67struct WelcomePayload {
68 session_id: String,
69 #[serde(skip_serializing_if = "Option::is_none")]
70 capabilities: Option<Vec<String>>,
71}
72
73/// Payload for error message (server → client)
74#[derive(Debug, Serialize, Deserialize)]
75struct ErrorPayload {
76 code: String,
77 message: String,
78 #[serde(skip_serializing_if = "Option::is_none")]
79 details: Option<serde_json::Value>,
80}
81
82/// Project information sent to Dashboard
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct ProjectInfo {
85 pub path: String,
86 pub name: String,
87 pub db_path: String,
88 #[serde(skip_serializing_if = "Option::is_none")]
89 pub agent: Option<String>,
90 /// Whether this project has an active MCP connection
91 pub mcp_connected: bool,
92 /// Whether the Dashboard serving this project is online
93 pub is_online: bool,
94}
95
/// Base reconnect delays in seconds: doubling from 1 s up to a 32 s ceiling.
///
/// The n-th consecutive failure (0-based) waits `RECONNECT_DELAYS[min(n, len - 1)]` seconds
/// before jitter is applied, so from the sixth consecutive failure on every retry uses the
/// final entry.
const RECONNECT_DELAYS: &[u64] = &[1, 2, 4, 8, 16, 32];
98
99/// Start WebSocket client with automatic reconnection
100/// This function runs indefinitely, reconnecting on disconnection
101pub async fn connect_to_dashboard(
102 project_path: PathBuf,
103 db_path: PathBuf,
104 agent: Option<String>,
105 notification_rx: Option<tokio::sync::mpsc::UnboundedReceiver<String>>,
106) -> Result<()> {
107 // Validate project path once at the beginning
108 let normalized_project_path = project_path
109 .canonicalize()
110 .unwrap_or_else(|_| project_path.clone());
111
112 let temp_dir = std::env::temp_dir()
113 .canonicalize()
114 .unwrap_or_else(|_| std::env::temp_dir());
115
116 if normalized_project_path.starts_with(&temp_dir) {
117 tracing::warn!(
118 "Skipping Dashboard registration for temporary path: {}",
119 normalized_project_path.display()
120 );
121 return Ok(()); // Silently skip for temp paths
122 }
123
124 let mut attempt = 0;
125
126 // Convert notification_rx to Option<Arc<Mutex<>>> for sharing across reconnections
127 let notification_rx = notification_rx.map(|rx| Arc::new(tokio::sync::Mutex::new(rx)));
128
129 // Infinite reconnection loop
130 loop {
131 tracing::info!("Connecting to Dashboard (attempt {})...", attempt + 1);
132
133 match connect_and_run(
134 project_path.clone(),
135 db_path.clone(),
136 agent.clone(),
137 notification_rx.clone(),
138 )
139 .await
140 {
141 Ok(()) => {
142 // Graceful close - reset attempt counter and retry immediately
143 tracing::info!("Dashboard connection closed gracefully, reconnecting...");
144 attempt = 0;
145 // Small delay before reconnecting
146 tokio::time::sleep(Duration::from_secs(1)).await;
147 },
148 Err(e) => {
149 // Connection error - use exponential backoff
150 tracing::warn!("Dashboard connection failed: {}. Retrying...", e);
151
152 // Calculate delay with exponential backoff
153 let delay_index = std::cmp::min(attempt, RECONNECT_DELAYS.len() - 1);
154 let base_delay = RECONNECT_DELAYS[delay_index];
155
156 // Add jitter: ±25% random variance
157 let jitter_factor = rand::random::<f64>() * 2.0 - 1.0; // Range: -1.0 to 1.0
158 let jitter_ms = (base_delay * 1000) as f64 * 0.25 * jitter_factor;
159 let delay_ms = (base_delay * 1000) as f64 + jitter_ms;
160 let delay = Duration::from_millis(delay_ms.max(0.0) as u64);
161
162 tracing::info!(
163 "Waiting {:.1}s before retry (base: {}s + jitter: {:.1}s)",
164 delay.as_secs_f64(),
165 base_delay,
166 jitter_ms / 1000.0
167 );
168
169 tokio::time::sleep(delay).await;
170 attempt += 1;
171 },
172 }
173 }
174}
175
176/// Internal function: Connect to Dashboard and run until disconnection
177/// Returns Ok(()) on graceful close, Err on connection failure
178async fn connect_and_run(
179 project_path: PathBuf,
180 db_path: PathBuf,
181 agent: Option<String>,
182 notification_rx: Option<Arc<tokio::sync::Mutex<tokio::sync::mpsc::UnboundedReceiver<String>>>>,
183) -> Result<()> {
184 // Extract project name from path
185 let project_name = project_path
186 .file_name()
187 .and_then(|n| n.to_str())
188 .unwrap_or("unknown")
189 .to_string();
190
191 // Normalize paths to handle symlinks
192 let normalized_project_path = project_path
193 .canonicalize()
194 .unwrap_or_else(|_| project_path.clone());
195 let normalized_db_path = db_path.canonicalize().unwrap_or_else(|_| db_path.clone());
196
197 // Create project info
198 let project_info = ProjectInfo {
199 path: normalized_project_path.to_string_lossy().to_string(),
200 name: project_name,
201 db_path: normalized_db_path.to_string_lossy().to_string(),
202 agent,
203 mcp_connected: true,
204 is_online: true,
205 };
206
207 // Connect to Dashboard WebSocket
208 let url = "ws://127.0.0.1:11391/ws/mcp";
209 let (ws_stream, _) = connect_async(url)
210 .await
211 .context("Failed to connect to Dashboard WebSocket")?;
212
213 tracing::debug!("Connected to Dashboard at {}", url);
214
215 let (mut write, mut read) = ws_stream.split();
216
217 // Step 1: Send hello message (Protocol v1.0 handshake)
218 let hello_msg = ProtocolMessage::new(
219 "hello",
220 HelloPayload {
221 entity_type: "mcp_server".to_string(),
222 capabilities: Some(vec![]),
223 },
224 );
225 write
226 .send(Message::Text(hello_msg.to_json()?))
227 .await
228 .context("Failed to send hello message")?;
229 tracing::debug!("Sent hello message");
230
231 // Step 2: Wait for welcome response
232 if let Some(Ok(Message::Text(text))) = read.next().await {
233 match serde_json::from_str::<ProtocolMessage<WelcomePayload>>(&text) {
234 Ok(msg) if msg.message_type == "welcome" => {
235 tracing::debug!(
236 "Received welcome from Dashboard (session: {})",
237 msg.payload.session_id
238 );
239 },
240 Ok(msg) => {
241 tracing::warn!(
242 "Expected welcome, received: {} (legacy Dashboard?)",
243 msg.message_type
244 );
245 // Continue anyway for backward compatibility
246 },
247 Err(e) => {
248 tracing::warn!("Failed to parse welcome message: {}", e);
249 },
250 }
251 }
252
253 // Step 3: Send registration message
254 let register_msg = ProtocolMessage::new("register", project_info.clone());
255 let register_json = register_msg.to_json()?;
256 write
257 .send(Message::Text(register_json))
258 .await
259 .context("Failed to send register message")?;
260
261 // Step 4: Wait for registration confirmation
262 if let Some(Ok(Message::Text(text))) = read.next().await {
263 match serde_json::from_str::<ProtocolMessage<RegisteredPayload>>(&text) {
264 Ok(msg) if msg.message_type == "registered" && msg.payload.success => {
265 tracing::debug!("Successfully registered with Dashboard");
266 },
267 Ok(msg) if msg.message_type == "registered" && !msg.payload.success => {
268 anyhow::bail!("Dashboard rejected registration");
269 },
270 _ => {
271 tracing::debug!("Unexpected response during registration: {}", text);
272 },
273 }
274 }
275
276 // Spawn read/write task to handle messages and respond to pings
277 // Protocol v1.0 Section 4.1.3: Dashboard sends ping, client responds with pong
278 tokio::spawn(async move {
279 loop {
280 // Handle notification channel if available
281 if let Some(ref rx) = notification_rx {
282 let mut rx_guard = rx.lock().await;
283 tokio::select! {
284 msg_result = read.next() => {
285 if let Some(Ok(msg)) = msg_result {
286 match msg {
287 Message::Text(text) => {
288 if let Ok(msg) =
289 serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
290 {
291 match msg.message_type.as_str() {
292 "ping" => {
293 // Dashboard sent ping - respond with pong
294 tracing::debug!(
295 "Received ping from Dashboard, responding with pong"
296 );
297 let pong_msg = ProtocolMessage::new("pong", EmptyPayload {});
298 if let Ok(pong_json) = pong_msg.to_json() {
299 if write.send(Message::Text(pong_json)).await.is_err() {
300 tracing::warn!(
301 "Failed to send pong - Dashboard connection lost"
302 );
303 break;
304 }
305 }
306 },
307 "error" => {
308 // Dashboard sent an error
309 if let Ok(error) =
310 serde_json::from_value::<ErrorPayload>(msg.payload)
311 {
312 tracing::error!(
313 "Dashboard error [{}]: {}",
314 error.code,
315 error.message
316 );
317 if let Some(details) = error.details {
318 tracing::error!(" Details: {}", details);
319 }
320
321 // Handle critical errors
322 match error.code.as_str() {
323 "unsupported_version" => {
324 tracing::error!(
325 "Protocol version mismatch - connection will close"
326 );
327 break;
328 },
329 "invalid_path" => {
330 tracing::error!("Project path rejected by Dashboard");
331 break;
332 },
333 _ => {
334 // Non-critical errors, continue
335 },
336 }
337 }
338 },
339 "goodbye" => {
340 // Dashboard is closing connection gracefully
341 if let Ok(goodbye) =
342 serde_json::from_value::<GoodbyePayload>(msg.payload)
343 {
344 if let Some(reason) = goodbye.reason {
345 tracing::info!("Dashboard closing connection: {}", reason);
346 } else {
347 tracing::info!("Dashboard closing connection gracefully");
348 }
349 }
350 break;
351 },
352 _ => {
353 tracing::debug!(
354 "Received message from Dashboard: {} ({})",
355 msg.message_type,
356 text
357 );
358 },
359 }
360 } else {
361 tracing::debug!("Received non-protocol message: {}", text);
362 }
363 },
364 Message::Close(_) => {
365 tracing::info!("Dashboard closed connection");
366 break;
367 },
368 _ => {}
369 }
370 } else {
371 // None or error - connection closed
372 tracing::info!("Dashboard WebSocket stream ended");
373 break;
374 }
375 }
376 notification_result = rx_guard.recv() => {
377 if let Some(notification) = notification_result {
378 // Send notification to Dashboard
379 if let Err(e) = write.send(Message::Text(notification)).await {
380 tracing::warn!("Failed to send notification to Dashboard: {}", e);
381 break;
382 }
383 tracing::debug!("Sent db_operation notification to Dashboard");
384 }
385 }
386 }
387 drop(rx_guard); // Release the lock after select!
388 } else {
389 // No notification channel - only handle WebSocket messages
390 tokio::select! {
391 msg_result = read.next() => {
392 if let Some(Ok(msg)) = msg_result {
393 match msg {
394 Message::Text(text) => {
395 if let Ok(msg) =
396 serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
397 {
398 match msg.message_type.as_str() {
399 "ping" => {
400 // Dashboard sent ping - respond with pong
401 tracing::debug!(
402 "Received ping from Dashboard, responding with pong"
403 );
404 let pong_msg = ProtocolMessage::new("pong", EmptyPayload {});
405 if let Ok(pong_json) = pong_msg.to_json() {
406 if write.send(Message::Text(pong_json)).await.is_err() {
407 tracing::warn!(
408 "Failed to send pong - Dashboard connection lost"
409 );
410 break;
411 }
412 }
413 },
414 "error" => {
415 // Dashboard sent an error
416 if let Ok(error) =
417 serde_json::from_value::<ErrorPayload>(msg.payload)
418 {
419 tracing::error!(
420 "Dashboard error [{}]: {}",
421 error.code,
422 error.message
423 );
424 if let Some(details) = error.details {
425 tracing::error!(" Details: {}", details);
426 }
427
428 // Handle critical errors
429 match error.code.as_str() {
430 "unsupported_version" => {
431 tracing::error!(
432 "Protocol version mismatch - connection will close"
433 );
434 break;
435 },
436 "invalid_path" => {
437 tracing::error!("Project path rejected by Dashboard");
438 break;
439 },
440 _ => {
441 // Non-critical errors, continue
442 },
443 }
444 }
445 },
446 "goodbye" => {
447 // Dashboard is closing connection gracefully
448 if let Ok(goodbye) =
449 serde_json::from_value::<GoodbyePayload>(msg.payload)
450 {
451 if let Some(reason) = goodbye.reason {
452 tracing::info!("Dashboard closing connection: {}", reason);
453 } else {
454 tracing::info!("Dashboard closing connection gracefully");
455 }
456 }
457 break;
458 },
459 _ => {
460 tracing::debug!(
461 "Received message from Dashboard: {} ({})",
462 msg.message_type,
463 text
464 );
465 },
466 }
467 } else {
468 tracing::debug!("Received non-protocol message: {}", text);
469 }
470 },
471 Message::Close(_) => {
472 tracing::info!("Dashboard closed connection");
473 break;
474 }
475 _ => {}
476 }
477 } else {
478 // None or error - connection closed
479 tracing::info!("Dashboard WebSocket stream ended");
480 break;
481 }
482 }
483 }
484 }
485 }
486 });
487
488 Ok(())
489}