intent_engine/mcp/ws_client.rs
1// WebSocket client for MCP → Dashboard communication
2// Handles registration and keep-alive for MCP server instances
3
4use anyhow::{Context, Result};
5use futures_util::{SinkExt, StreamExt};
6use serde::{Deserialize, Serialize};
7use std::path::PathBuf;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio_tungstenite::{connect_async, tungstenite::Message};
11
/// Protocol version
///
/// Stamped into the `version` field of every envelope built by `ProtocolMessage::new`.
const PROTOCOL_VERSION: &str = "1.0";

/// Protocol message wrapper
///
/// Envelope used for every message exchanged with the Dashboard. It is serialized as a JSON
/// object with the keys `version`, `type`, `payload` and `timestamp`, in that order.
#[derive(Debug, Serialize, Deserialize)]
struct ProtocolMessage<T> {
    /// Protocol version string (`PROTOCOL_VERSION` for messages built by this client).
    version: String,
    /// Message kind (e.g. `hello`, `register`, `ping`, `error`); on the wire the key is `type`.
    #[serde(rename = "type")]
    message_type: String,
    /// Message-specific body whose shape depends on `message_type`.
    payload: T,
    /// Creation time; `ProtocolMessage::new` fills it with an RFC 3339 UTC timestamp.
    timestamp: String,
}
24
25impl<T: Serialize> ProtocolMessage<T> {
26 fn new(message_type: impl Into<String>, payload: T) -> Self {
27 Self {
28 version: PROTOCOL_VERSION.to_string(),
29 message_type: message_type.into(),
30 payload,
31 timestamp: chrono::Utc::now().to_rfc3339(),
32 }
33 }
34
35 fn to_json(&self) -> Result<String> {
36 serde_json::to_string(self).map_err(Into::into)
37 }
38}
39
/// Empty payload for ping/pong messages
///
/// Declared as a braced struct with no fields, which serde_json serializes as `{}`.
/// A unit struct (`struct EmptyPayload;`) would serialize as `null` instead.
#[derive(Debug, Serialize, Deserialize)]
struct EmptyPayload {}
43
/// Payload for registered response
///
/// The Dashboard's answer to the `register` message. Registration is treated as rejected
/// (and the connection attempt fails) when `success` is `false`.
#[derive(Debug, Serialize, Deserialize)]
struct RegisteredPayload {
    /// Whether the Dashboard accepted the registration.
    success: bool,
}
49
/// Payload for goodbye message
///
/// Received when the Dashboard is closing the connection gracefully.
#[derive(Debug, Serialize, Deserialize)]
struct GoodbyePayload {
    /// Optional human-readable reason; logged when present, omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    reason: Option<String>,
}
56
/// Payload for hello message (client → server)
///
/// First message of the handshake. This client sends `entity_type = "mcp_server"` and an
/// empty capability list.
#[derive(Debug, Serialize, Deserialize)]
struct HelloPayload {
    /// Kind of peer announcing itself.
    entity_type: String,
    /// Capabilities advertised by the client; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    capabilities: Option<Vec<String>>,
}
64
/// Payload for welcome message (server → client)
///
/// Expected reply to `hello`. If it cannot be parsed, the client only logs a warning and
/// continues the handshake.
#[derive(Debug, Serialize, Deserialize)]
struct WelcomePayload {
    /// Session identifier sent by the Dashboard (this client only logs it).
    session_id: String,
    /// Capabilities advertised by the Dashboard; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    capabilities: Option<Vec<String>>,
}
72
/// Payload for error message (server → client)
#[derive(Debug, Serialize, Deserialize)]
struct ErrorPayload {
    /// Machine-readable error code. `unsupported_version` and `invalid_path` end the
    /// session; other codes are only logged.
    code: String,
    /// Human-readable description of the error.
    message: String,
    /// Optional structured context, logged verbatim; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    details: Option<serde_json::Value>,
}
81
/// Project information sent to Dashboard
///
/// This module sends it as the payload of the `register` message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectInfo {
    /// Project root path (this module canonicalizes it when possible).
    pub path: String,
    /// Display name (this module uses the last path component, or `"unknown"`).
    pub name: String,
    /// Path to the project's database (canonicalized by this module when possible).
    pub db_path: String,
    /// Optional agent identifier; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent: Option<String>,
    /// Whether this project has an active MCP connection
    pub mcp_connected: bool,
    /// Whether the Dashboard serving this project is online
    pub is_online: bool,
}
95
/// Base reconnection delays in seconds, indexed by the number of consecutive failures.
///
/// Each step doubles the previous one; failures beyond the end keep reusing the last entry.
const RECONNECT_DELAYS: &[u64] = &[1, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5];
98
99/// Start WebSocket client with automatic reconnection
100/// This function runs indefinitely, reconnecting on disconnection
101pub async fn connect_to_dashboard(
102 project_path: PathBuf,
103 db_path: PathBuf,
104 agent: Option<String>,
105 notification_rx: Option<tokio::sync::mpsc::UnboundedReceiver<String>>,
106 dashboard_port: Option<u16>,
107) -> Result<()> {
108 // Validate project path once at the beginning
109 let normalized_project_path = project_path
110 .canonicalize()
111 .unwrap_or_else(|_| project_path.clone());
112
113 let temp_dir = std::env::temp_dir()
114 .canonicalize()
115 .unwrap_or_else(|_| std::env::temp_dir());
116
117 if normalized_project_path.starts_with(&temp_dir) {
118 tracing::warn!(
119 "Skipping Dashboard registration for temporary path: {}",
120 normalized_project_path.display()
121 );
122 return Ok(()); // Silently skip for temp paths
123 }
124
125 let mut attempt = 0;
126
127 // Convert notification_rx to Option<Arc<Mutex<>>> for sharing across reconnections
128 let notification_rx = notification_rx.map(|rx| Arc::new(tokio::sync::Mutex::new(rx)));
129
130 // Infinite reconnection loop
131 loop {
132 tracing::info!("Connecting to Dashboard (attempt {})...", attempt + 1);
133
134 match connect_and_run(
135 project_path.clone(),
136 db_path.clone(),
137 agent.clone(),
138 notification_rx.clone(),
139 dashboard_port,
140 )
141 .await
142 {
143 Ok(()) => {
144 // Graceful close - reset attempt counter and retry immediately
145 tracing::info!("Dashboard connection closed gracefully, reconnecting...");
146 attempt = 0;
147 // Small delay before reconnecting
148 tokio::time::sleep(Duration::from_secs(1)).await;
149 },
150 Err(e) => {
151 // Connection error - use exponential backoff
152 tracing::warn!("Dashboard connection failed: {}. Retrying...", e);
153
154 // Calculate delay with exponential backoff
155 let delay_index = std::cmp::min(attempt, RECONNECT_DELAYS.len() - 1);
156 let base_delay = RECONNECT_DELAYS[delay_index];
157
158 // Add jitter: ±25% random variance
159 let jitter_factor = rand::random::<f64>() * 2.0 - 1.0; // Range: -1.0 to 1.0
160 let jitter_ms = (base_delay * 1000) as f64 * 0.25 * jitter_factor;
161 let delay_ms = (base_delay * 1000) as f64 + jitter_ms;
162 let delay = Duration::from_millis(delay_ms.max(0.0) as u64);
163
164 tracing::info!(
165 "Waiting {:.1}s before retry (base: {}s + jitter: {:.1}s)",
166 delay.as_secs_f64(),
167 base_delay,
168 jitter_ms / 1000.0
169 );
170
171 tokio::time::sleep(delay).await;
172 attempt += 1;
173 },
174 }
175 }
176}
177
178/// Internal function: Connect to Dashboard and run until disconnection
179/// Returns Ok(()) on graceful close, Err on connection failure
180async fn connect_and_run(
181 project_path: PathBuf,
182 db_path: PathBuf,
183 agent: Option<String>,
184 notification_rx: Option<Arc<tokio::sync::Mutex<tokio::sync::mpsc::UnboundedReceiver<String>>>>,
185 dashboard_port: Option<u16>,
186) -> Result<()> {
187 // Extract project name from path
188 let project_name = project_path
189 .file_name()
190 .and_then(|n| n.to_str())
191 .unwrap_or("unknown")
192 .to_string();
193
194 // Normalize paths to handle symlinks
195 let normalized_project_path = project_path
196 .canonicalize()
197 .unwrap_or_else(|_| project_path.clone());
198 let normalized_db_path = db_path.canonicalize().unwrap_or_else(|_| db_path.clone());
199
200 // Create project info
201 let project_info = ProjectInfo {
202 path: normalized_project_path.to_string_lossy().to_string(),
203 name: project_name,
204 db_path: normalized_db_path.to_string_lossy().to_string(),
205 agent,
206 mcp_connected: true,
207 is_online: true,
208 };
209
210 // Connect to Dashboard WebSocket
211 let port = dashboard_port.unwrap_or(11391);
212 let url = format!("ws://127.0.0.1:{}/ws/mcp", port);
213 let (ws_stream, _) = connect_async(&url)
214 .await
215 .context("Failed to connect to Dashboard WebSocket")?;
216
217 tracing::debug!("Connected to Dashboard at {}", url);
218
219 let (mut write, mut read) = ws_stream.split();
220
221 // Step 1: Send hello message (Protocol v1.0 handshake)
222 let hello_msg = ProtocolMessage::new(
223 "hello",
224 HelloPayload {
225 entity_type: "mcp_server".to_string(),
226 capabilities: Some(vec![]),
227 },
228 );
229 write
230 .send(Message::Text(hello_msg.to_json()?))
231 .await
232 .context("Failed to send hello message")?;
233 tracing::debug!("Sent hello message");
234
235 // Step 2: Wait for welcome response
236 if let Some(Ok(Message::Text(text))) = read.next().await {
237 match serde_json::from_str::<ProtocolMessage<WelcomePayload>>(&text) {
238 Ok(msg) if msg.message_type == "welcome" => {
239 tracing::debug!(
240 "Received welcome from Dashboard (session: {})",
241 msg.payload.session_id
242 );
243 },
244 Ok(msg) => {
245 tracing::warn!(
246 "Expected welcome, received: {} (legacy Dashboard?)",
247 msg.message_type
248 );
249 // Continue anyway for backward compatibility
250 },
251 Err(e) => {
252 tracing::warn!("Failed to parse welcome message: {}", e);
253 },
254 }
255 }
256
257 // Step 3: Send registration message
258 let register_msg = ProtocolMessage::new("register", project_info.clone());
259 let register_json = register_msg.to_json()?;
260 write
261 .send(Message::Text(register_json))
262 .await
263 .context("Failed to send register message")?;
264
265 // Step 4: Wait for registration confirmation
266 if let Some(Ok(Message::Text(text))) = read.next().await {
267 match serde_json::from_str::<ProtocolMessage<RegisteredPayload>>(&text) {
268 Ok(msg) if msg.message_type == "registered" && msg.payload.success => {
269 tracing::debug!("Successfully registered with Dashboard");
270 },
271 Ok(msg) if msg.message_type == "registered" && !msg.payload.success => {
272 anyhow::bail!("Dashboard rejected registration");
273 },
274 _ => {
275 tracing::debug!("Unexpected response during registration: {}", text);
276 },
277 }
278 }
279
280 // Spawn read/write task to handle messages and respond to pings
281 // Protocol v1.0 Section 4.1.3: Dashboard sends ping, client responds with pong
282 tokio::spawn(async move {
283 loop {
284 // Handle notification channel if available
285 if let Some(ref rx) = notification_rx {
286 let mut rx_guard = rx.lock().await;
287 tokio::select! {
288 msg_result = read.next() => {
289 if let Some(Ok(msg)) = msg_result {
290 match msg {
291 Message::Text(text) => {
292 if let Ok(msg) =
293 serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
294 {
295 match msg.message_type.as_str() {
296 "ping" => {
297 // Dashboard sent ping - respond with pong
298 tracing::debug!(
299 "Received ping from Dashboard, responding with pong"
300 );
301 let pong_msg = ProtocolMessage::new("pong", EmptyPayload {});
302 if let Ok(pong_json) = pong_msg.to_json() {
303 if write.send(Message::Text(pong_json)).await.is_err() {
304 tracing::warn!(
305 "Failed to send pong - Dashboard connection lost"
306 );
307 break;
308 }
309 }
310 },
311 "error" => {
312 // Dashboard sent an error
313 if let Ok(error) =
314 serde_json::from_value::<ErrorPayload>(msg.payload)
315 {
316 tracing::error!(
317 "Dashboard error [{}]: {}",
318 error.code,
319 error.message
320 );
321 if let Some(details) = error.details {
322 tracing::error!(" Details: {}", details);
323 }
324
325 // Handle critical errors
326 match error.code.as_str() {
327 "unsupported_version" => {
328 tracing::error!(
329 "Protocol version mismatch - connection will close"
330 );
331 break;
332 },
333 "invalid_path" => {
334 tracing::error!("Project path rejected by Dashboard");
335 break;
336 },
337 _ => {
338 // Non-critical errors, continue
339 },
340 }
341 }
342 },
343 "goodbye" => {
344 // Dashboard is closing connection gracefully
345 if let Ok(goodbye) =
346 serde_json::from_value::<GoodbyePayload>(msg.payload)
347 {
348 if let Some(reason) = goodbye.reason {
349 tracing::info!("Dashboard closing connection: {}", reason);
350 } else {
351 tracing::info!("Dashboard closing connection gracefully");
352 }
353 }
354 break;
355 },
356 _ => {
357 tracing::debug!(
358 "Received message from Dashboard: {} ({})",
359 msg.message_type,
360 text
361 );
362 },
363 }
364 } else {
365 tracing::debug!("Received non-protocol message: {}", text);
366 }
367 },
368 Message::Close(_) => {
369 tracing::info!("Dashboard closed connection");
370 break;
371 },
372 _ => {}
373 }
374 } else {
375 // None or error - connection closed
376 tracing::info!("Dashboard WebSocket stream ended");
377 break;
378 }
379 }
380 notification_result = rx_guard.recv() => {
381 if let Some(notification) = notification_result {
382 // Send notification to Dashboard
383 if let Err(e) = write.send(Message::Text(notification)).await {
384 tracing::warn!("Failed to send notification to Dashboard: {}", e);
385 break;
386 }
387 tracing::debug!("Sent db_operation notification to Dashboard");
388 }
389 }
390 }
391 drop(rx_guard); // Release the lock after select!
392 } else {
393 // No notification channel - only handle WebSocket messages
394 tokio::select! {
395 msg_result = read.next() => {
396 if let Some(Ok(msg)) = msg_result {
397 match msg {
398 Message::Text(text) => {
399 if let Ok(msg) =
400 serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
401 {
402 match msg.message_type.as_str() {
403 "ping" => {
404 // Dashboard sent ping - respond with pong
405 tracing::debug!(
406 "Received ping from Dashboard, responding with pong"
407 );
408 let pong_msg = ProtocolMessage::new("pong", EmptyPayload {});
409 if let Ok(pong_json) = pong_msg.to_json() {
410 if write.send(Message::Text(pong_json)).await.is_err() {
411 tracing::warn!(
412 "Failed to send pong - Dashboard connection lost"
413 );
414 break;
415 }
416 }
417 },
418 "error" => {
419 // Dashboard sent an error
420 if let Ok(error) =
421 serde_json::from_value::<ErrorPayload>(msg.payload)
422 {
423 tracing::error!(
424 "Dashboard error [{}]: {}",
425 error.code,
426 error.message
427 );
428 if let Some(details) = error.details {
429 tracing::error!(" Details: {}", details);
430 }
431
432 // Handle critical errors
433 match error.code.as_str() {
434 "unsupported_version" => {
435 tracing::error!(
436 "Protocol version mismatch - connection will close"
437 );
438 break;
439 },
440 "invalid_path" => {
441 tracing::error!("Project path rejected by Dashboard");
442 break;
443 },
444 _ => {
445 // Non-critical errors, continue
446 },
447 }
448 }
449 },
450 "goodbye" => {
451 // Dashboard is closing connection gracefully
452 if let Ok(goodbye) =
453 serde_json::from_value::<GoodbyePayload>(msg.payload)
454 {
455 if let Some(reason) = goodbye.reason {
456 tracing::info!("Dashboard closing connection: {}", reason);
457 } else {
458 tracing::info!("Dashboard closing connection gracefully");
459 }
460 }
461 break;
462 },
463 _ => {
464 tracing::debug!(
465 "Received message from Dashboard: {} ({})",
466 msg.message_type,
467 text
468 );
469 },
470 }
471 } else {
472 tracing::debug!("Received non-protocol message: {}", text);
473 }
474 },
475 Message::Close(_) => {
476 tracing::info!("Dashboard closed connection");
477 break;
478 }
479 _ => {}
480 }
481 } else {
482 // None or error - connection closed
483 tracing::info!("Dashboard WebSocket stream ended");
484 break;
485 }
486 }
487 }
488 }
489 }
490 });
491
492 Ok(())
493}