intent_engine/mcp/ws_client.rs
1//! WebSocket client for MCP → Dashboard communication
2//!
3//! Handles registration and keep-alive for MCP server instances.
4//!
5//! # Testing Strategy
6//!
7//! This module requires a running Dashboard server for testing.
8//! Tests are located in `tests/ws_client_integration_test.rs` and marked with `#[ignore]`.
9//!
10//! To run integration tests:
11//! ```bash
12//! # Start Dashboard first
13//! ie dashboard start
14//!
15//! # Run integration tests
16//! cargo test --test ws_client_integration_test -- --ignored
17//! ```
18
19use anyhow::{Context, Result};
20use futures_util::{SinkExt, StreamExt};
21use serde::{Deserialize, Serialize};
22use std::path::PathBuf;
23use std::sync::Arc;
24use std::time::Duration;
25use tokio_tungstenite::{connect_async, tungstenite::Message};
26
27/// Protocol version
28const PROTOCOL_VERSION: &str = "1.0";
29
30/// Protocol message wrapper
31#[derive(Debug, Serialize, Deserialize)]
32struct ProtocolMessage<T> {
33 version: String,
34 #[serde(rename = "type")]
35 message_type: String,
36 payload: T,
37 timestamp: String,
38}
39
40impl<T: Serialize> ProtocolMessage<T> {
41 fn new(message_type: impl Into<String>, payload: T) -> Self {
42 Self {
43 version: PROTOCOL_VERSION.to_string(),
44 message_type: message_type.into(),
45 payload,
46 timestamp: chrono::Utc::now().to_rfc3339(),
47 }
48 }
49
50 fn to_json(&self) -> Result<String> {
51 serde_json::to_string(self).map_err(Into::into)
52 }
53}
54
55/// Empty payload for ping/pong messages
56#[derive(Debug, Serialize, Deserialize)]
57struct EmptyPayload {}
58
59/// Payload for registered response
60#[derive(Debug, Serialize, Deserialize)]
61struct RegisteredPayload {
62 success: bool,
63}
64
65/// Payload for goodbye message
66#[derive(Debug, Serialize, Deserialize)]
67struct GoodbyePayload {
68 #[serde(skip_serializing_if = "Option::is_none")]
69 reason: Option<String>,
70}
71
72/// Payload for hello message (client → server)
73#[derive(Debug, Serialize, Deserialize)]
74struct HelloPayload {
75 entity_type: String,
76 #[serde(skip_serializing_if = "Option::is_none")]
77 capabilities: Option<Vec<String>>,
78}
79
80/// Payload for welcome message (server → client)
81#[derive(Debug, Serialize, Deserialize)]
82struct WelcomePayload {
83 session_id: String,
84 #[serde(skip_serializing_if = "Option::is_none")]
85 capabilities: Option<Vec<String>>,
86}
87
88/// Payload for error message (server → client)
89#[derive(Debug, Serialize, Deserialize)]
90struct ErrorPayload {
91 code: String,
92 message: String,
93 #[serde(skip_serializing_if = "Option::is_none")]
94 details: Option<serde_json::Value>,
95}
96
97/// Project information sent to Dashboard
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct ProjectInfo {
100 pub path: String,
101 pub name: String,
102 pub db_path: String,
103 #[serde(skip_serializing_if = "Option::is_none")]
104 pub agent: Option<String>,
105 /// Whether this project has an active MCP connection
106 pub mcp_connected: bool,
107 /// Whether the Dashboard serving this project is online
108 pub is_online: bool,
109}
110
111/// Reconnection delays in seconds (exponential backoff with max)
112const RECONNECT_DELAYS: &[u64] = &[1, 2, 4, 8, 16, 32];
113
114/// Start WebSocket client with automatic reconnection
115/// This function runs indefinitely, reconnecting on disconnection
116pub async fn connect_to_dashboard(
117 project_path: PathBuf,
118 db_path: PathBuf,
119 agent: Option<String>,
120 notification_rx: Option<tokio::sync::mpsc::UnboundedReceiver<String>>,
121 dashboard_port: Option<u16>,
122) -> Result<()> {
123 // Validate project path once at the beginning
124 let normalized_project_path = project_path
125 .canonicalize()
126 .unwrap_or_else(|_| project_path.clone());
127
128 let temp_dir = std::env::temp_dir()
129 .canonicalize()
130 .unwrap_or_else(|_| std::env::temp_dir());
131
132 if normalized_project_path.starts_with(&temp_dir) {
133 tracing::warn!(
134 "Skipping Dashboard registration for temporary path: {}",
135 normalized_project_path.display()
136 );
137 return Ok(()); // Silently skip for temp paths
138 }
139
140 let mut attempt = 0;
141
142 // Convert notification_rx to Option<Arc<Mutex<>>> for sharing across reconnections
143 let notification_rx = notification_rx.map(|rx| Arc::new(tokio::sync::Mutex::new(rx)));
144
145 // Infinite reconnection loop
146 loop {
147 tracing::info!("Connecting to Dashboard (attempt {})...", attempt + 1);
148
149 match connect_and_run(
150 project_path.clone(),
151 db_path.clone(),
152 agent.clone(),
153 notification_rx.clone(),
154 dashboard_port,
155 )
156 .await
157 {
158 Ok(()) => {
159 // Graceful close - reset attempt counter and retry immediately
160 tracing::info!("Dashboard connection closed gracefully, reconnecting...");
161 attempt = 0;
162 // Small delay before reconnecting
163 tokio::time::sleep(Duration::from_secs(1)).await;
164 },
165 Err(e) => {
166 // Connection error - use exponential backoff
167 tracing::warn!("Dashboard connection failed: {}. Retrying...", e);
168
169 // Calculate delay with exponential backoff
170 let delay_index = std::cmp::min(attempt, RECONNECT_DELAYS.len() - 1);
171 let base_delay = RECONNECT_DELAYS[delay_index];
172
173 // Add jitter: ±25% random variance
174 let jitter_factor = rand::random::<f64>() * 2.0 - 1.0; // Range: -1.0 to 1.0
175 let jitter_ms = (base_delay * 1000) as f64 * 0.25 * jitter_factor;
176 let delay_ms = (base_delay * 1000) as f64 + jitter_ms;
177 let delay = Duration::from_millis(delay_ms.max(0.0) as u64);
178
179 tracing::info!(
180 "Waiting {:.1}s before retry (base: {}s + jitter: {:.1}s)",
181 delay.as_secs_f64(),
182 base_delay,
183 jitter_ms / 1000.0
184 );
185
186 tokio::time::sleep(delay).await;
187 attempt += 1;
188 },
189 }
190 }
191}
192
193/// Internal function: Connect to Dashboard and run until disconnection
194/// Returns Ok(()) on graceful close, Err on connection failure
195async fn connect_and_run(
196 project_path: PathBuf,
197 db_path: PathBuf,
198 agent: Option<String>,
199 notification_rx: Option<Arc<tokio::sync::Mutex<tokio::sync::mpsc::UnboundedReceiver<String>>>>,
200 dashboard_port: Option<u16>,
201) -> Result<()> {
202 // Extract project name from path
203 let project_name = project_path
204 .file_name()
205 .and_then(|n| n.to_str())
206 .unwrap_or("unknown")
207 .to_string();
208
209 // Normalize paths to handle symlinks
210 let normalized_project_path = project_path
211 .canonicalize()
212 .unwrap_or_else(|_| project_path.clone());
213 let normalized_db_path = db_path.canonicalize().unwrap_or_else(|_| db_path.clone());
214
215 // Create project info
216 let project_info = ProjectInfo {
217 path: normalized_project_path.to_string_lossy().to_string(),
218 name: project_name,
219 db_path: normalized_db_path.to_string_lossy().to_string(),
220 agent,
221 mcp_connected: true,
222 is_online: true,
223 };
224
225 // Connect to Dashboard WebSocket
226 let port = dashboard_port.unwrap_or(11391);
227 let url = format!("ws://127.0.0.1:{}/ws/mcp", port);
228 let (ws_stream, _) = connect_async(&url)
229 .await
230 .context("Failed to connect to Dashboard WebSocket")?;
231
232 tracing::debug!("Connected to Dashboard at {}", url);
233
234 let (mut write, mut read) = ws_stream.split();
235
236 // Step 1: Send hello message (Protocol v1.0 handshake)
237 let hello_msg = ProtocolMessage::new(
238 "hello",
239 HelloPayload {
240 entity_type: "mcp_server".to_string(),
241 capabilities: Some(vec![]),
242 },
243 );
244 write
245 .send(Message::Text(hello_msg.to_json()?))
246 .await
247 .context("Failed to send hello message")?;
248 tracing::debug!("Sent hello message");
249
250 // Step 2: Wait for welcome response
251 if let Some(Ok(Message::Text(text))) = read.next().await {
252 match serde_json::from_str::<ProtocolMessage<WelcomePayload>>(&text) {
253 Ok(msg) if msg.message_type == "welcome" => {
254 tracing::debug!(
255 "Received welcome from Dashboard (session: {})",
256 msg.payload.session_id
257 );
258 },
259 Ok(msg) => {
260 tracing::warn!(
261 "Expected welcome, received: {} (legacy Dashboard?)",
262 msg.message_type
263 );
264 // Continue anyway for backward compatibility
265 },
266 Err(e) => {
267 tracing::warn!("Failed to parse welcome message: {}", e);
268 },
269 }
270 }
271
272 // Step 3: Send registration message
273 let register_msg = ProtocolMessage::new("register", project_info.clone());
274 let register_json = register_msg.to_json()?;
275 write
276 .send(Message::Text(register_json))
277 .await
278 .context("Failed to send register message")?;
279
280 // Step 4: Wait for registration confirmation
281 if let Some(Ok(Message::Text(text))) = read.next().await {
282 match serde_json::from_str::<ProtocolMessage<RegisteredPayload>>(&text) {
283 Ok(msg) if msg.message_type == "registered" && msg.payload.success => {
284 tracing::debug!("Successfully registered with Dashboard");
285 },
286 Ok(msg) if msg.message_type == "registered" && !msg.payload.success => {
287 anyhow::bail!("Dashboard rejected registration");
288 },
289 _ => {
290 tracing::debug!("Unexpected response during registration: {}", text);
291 },
292 }
293 }
294
295 // Spawn read/write task to handle messages and respond to pings
296 // Protocol v1.0 Section 4.1.3: Dashboard sends ping, client responds with pong
297 tokio::spawn(async move {
298 loop {
299 // Handle notification channel if available
300 if let Some(ref rx) = notification_rx {
301 let mut rx_guard = rx.lock().await;
302 tokio::select! {
303 msg_result = read.next() => {
304 if let Some(Ok(msg)) = msg_result {
305 match msg {
306 Message::Text(text) => {
307 if let Ok(msg) =
308 serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
309 {
310 match msg.message_type.as_str() {
311 "ping" => {
312 // Dashboard sent ping - respond with pong
313 tracing::debug!(
314 "Received ping from Dashboard, responding with pong"
315 );
316 let pong_msg = ProtocolMessage::new("pong", EmptyPayload {});
317 if let Ok(pong_json) = pong_msg.to_json() {
318 if write.send(Message::Text(pong_json)).await.is_err() {
319 tracing::warn!(
320 "Failed to send pong - Dashboard connection lost"
321 );
322 break;
323 }
324 }
325 },
326 "error" => {
327 // Dashboard sent an error
328 if let Ok(error) =
329 serde_json::from_value::<ErrorPayload>(msg.payload)
330 {
331 tracing::error!(
332 "Dashboard error [{}]: {}",
333 error.code,
334 error.message
335 );
336 if let Some(details) = error.details {
337 tracing::error!(" Details: {}", details);
338 }
339
340 // Handle critical errors
341 match error.code.as_str() {
342 "unsupported_version" => {
343 tracing::error!(
344 "Protocol version mismatch - connection will close"
345 );
346 break;
347 },
348 "invalid_path" => {
349 tracing::error!("Project path rejected by Dashboard");
350 break;
351 },
352 _ => {
353 // Non-critical errors, continue
354 },
355 }
356 }
357 },
358 "goodbye" => {
359 // Dashboard is closing connection gracefully
360 if let Ok(goodbye) =
361 serde_json::from_value::<GoodbyePayload>(msg.payload)
362 {
363 if let Some(reason) = goodbye.reason {
364 tracing::info!("Dashboard closing connection: {}", reason);
365 } else {
366 tracing::info!("Dashboard closing connection gracefully");
367 }
368 }
369 break;
370 },
371 _ => {
372 tracing::debug!(
373 "Received message from Dashboard: {} ({})",
374 msg.message_type,
375 text
376 );
377 },
378 }
379 } else {
380 tracing::debug!("Received non-protocol message: {}", text);
381 }
382 },
383 Message::Close(_) => {
384 tracing::info!("Dashboard closed connection");
385 break;
386 },
387 _ => {}
388 }
389 } else {
390 // None or error - connection closed
391 tracing::info!("Dashboard WebSocket stream ended");
392 break;
393 }
394 }
395 notification_result = rx_guard.recv() => {
396 if let Some(notification) = notification_result {
397 // Send notification to Dashboard
398 if let Err(e) = write.send(Message::Text(notification)).await {
399 tracing::warn!("Failed to send notification to Dashboard: {}", e);
400 break;
401 }
402 tracing::debug!("Sent db_operation notification to Dashboard");
403 }
404 }
405 }
406 drop(rx_guard); // Release the lock after select!
407 } else {
408 // No notification channel - only handle WebSocket messages
409 tokio::select! {
410 msg_result = read.next() => {
411 if let Some(Ok(msg)) = msg_result {
412 match msg {
413 Message::Text(text) => {
414 if let Ok(msg) =
415 serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
416 {
417 match msg.message_type.as_str() {
418 "ping" => {
419 // Dashboard sent ping - respond with pong
420 tracing::debug!(
421 "Received ping from Dashboard, responding with pong"
422 );
423 let pong_msg = ProtocolMessage::new("pong", EmptyPayload {});
424 if let Ok(pong_json) = pong_msg.to_json() {
425 if write.send(Message::Text(pong_json)).await.is_err() {
426 tracing::warn!(
427 "Failed to send pong - Dashboard connection lost"
428 );
429 break;
430 }
431 }
432 },
433 "error" => {
434 // Dashboard sent an error
435 if let Ok(error) =
436 serde_json::from_value::<ErrorPayload>(msg.payload)
437 {
438 tracing::error!(
439 "Dashboard error [{}]: {}",
440 error.code,
441 error.message
442 );
443 if let Some(details) = error.details {
444 tracing::error!(" Details: {}", details);
445 }
446
447 // Handle critical errors
448 match error.code.as_str() {
449 "unsupported_version" => {
450 tracing::error!(
451 "Protocol version mismatch - connection will close"
452 );
453 break;
454 },
455 "invalid_path" => {
456 tracing::error!("Project path rejected by Dashboard");
457 break;
458 },
459 _ => {
460 // Non-critical errors, continue
461 },
462 }
463 }
464 },
465 "goodbye" => {
466 // Dashboard is closing connection gracefully
467 if let Ok(goodbye) =
468 serde_json::from_value::<GoodbyePayload>(msg.payload)
469 {
470 if let Some(reason) = goodbye.reason {
471 tracing::info!("Dashboard closing connection: {}", reason);
472 } else {
473 tracing::info!("Dashboard closing connection gracefully");
474 }
475 }
476 break;
477 },
478 _ => {
479 tracing::debug!(
480 "Received message from Dashboard: {} ({})",
481 msg.message_type,
482 text
483 );
484 },
485 }
486 } else {
487 tracing::debug!("Received non-protocol message: {}", text);
488 }
489 },
490 Message::Close(_) => {
491 tracing::info!("Dashboard closed connection");
492 break;
493 }
494 _ => {}
495 }
496 } else {
497 // None or error - connection closed
498 tracing::info!("Dashboard WebSocket stream ended");
499 break;
500 }
501 }
502 }
503 }
504 }
505 });
506
507 Ok(())
508}