/// Wire-protocol version stamped on every outgoing `ProtocolMessage`.
///
/// Incoming messages are accepted when their major component (the text before
/// the first `.`) equals the major component of this value.
pub const PROTOCOL_VERSION: &str = "1.0";
6
7use axum::{
8 extract::{
9 ws::{Message, WebSocket},
10 State, WebSocketUpgrade,
11 },
12 response::IntoResponse,
13};
14use futures_util::{SinkExt, StreamExt};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19
20#[derive(Debug, Serialize, Deserialize)]
22pub struct ProtocolMessage<T> {
23 pub version: String,
25 #[serde(rename = "type")]
27 pub message_type: String,
28 pub payload: T,
30 pub timestamp: String,
32}
33
34impl<T> ProtocolMessage<T>
35where
36 T: Serialize,
37{
38 pub fn new(message_type: impl Into<String>, payload: T) -> Self {
40 Self {
41 version: PROTOCOL_VERSION.to_string(),
42 message_type: message_type.into(),
43 payload,
44 timestamp: chrono::Utc::now().to_rfc3339(),
45 }
46 }
47
48 pub fn to_json(&self) -> Result<String, serde_json::Error> {
50 serde_json::to_string(self)
51 }
52}
53
54impl<T> ProtocolMessage<T>
55where
56 T: for<'de> Deserialize<'de>,
57{
58 pub fn from_json(json: &str) -> Result<Self, String> {
60 let msg: Self = serde_json::from_str(json)
61 .map_err(|e| format!("Failed to parse protocol message: {}", e))?;
62
63 let expected_major = PROTOCOL_VERSION.split('.').next().unwrap_or("1");
65 let received_major = msg.version.split('.').next().unwrap_or("0");
66
67 if expected_major != received_major {
68 return Err(format!(
69 "Protocol version mismatch: expected {}, got {}",
70 PROTOCOL_VERSION, msg.version
71 ));
72 }
73
74 Ok(msg)
75 }
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ProjectInfo {
81 pub path: String,
82 pub name: String,
83 pub db_path: String,
84 #[serde(skip_serializing_if = "Option::is_none")]
85 pub agent: Option<String>,
86 pub mcp_connected: bool,
88 pub is_online: bool,
90}
91
92#[derive(Debug)]
94pub struct McpConnection {
95 pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
96 pub project: ProjectInfo,
97 pub connected_at: chrono::DateTime<chrono::Utc>,
98}
99
100#[derive(Debug)]
102pub struct UiConnection {
103 pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
104 pub connected_at: chrono::DateTime<chrono::Utc>,
105}
106
107#[derive(Clone)]
109pub struct WebSocketState {
110 pub mcp_connections: Arc<RwLock<HashMap<String, McpConnection>>>,
112 pub ui_connections: Arc<RwLock<Vec<UiConnection>>>,
114}
115
116impl Default for WebSocketState {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
122impl WebSocketState {
123 pub fn new() -> Self {
124 Self {
125 mcp_connections: Arc::new(RwLock::new(HashMap::new())),
126 ui_connections: Arc::new(RwLock::new(Vec::new())),
127 }
128 }
129
130 pub async fn broadcast_to_ui(&self, message: &str) {
132 let connections = self.ui_connections.read().await;
133 for conn in connections.iter() {
134 let _ = conn.tx.send(Message::Text(message.to_string()));
135 }
136 }
137
138 pub async fn get_online_projects(&self) -> Vec<ProjectInfo> {
140 let connections = self.mcp_connections.read().await;
142
143 connections
144 .values()
145 .map(|conn| {
146 let mut project = conn.project.clone();
147 project.mcp_connected = true; project
149 })
150 .collect()
151 }
152
153 pub async fn get_online_projects_with_current(
157 &self,
158 current_project_name: &str,
159 current_project_path: &std::path::Path,
160 current_db_path: &std::path::Path,
161 _port: u16,
162 ) -> Vec<ProjectInfo> {
163 let connections = self.mcp_connections.read().await;
164 let current_path_str = current_project_path.display().to_string();
165
166 let mut projects = Vec::new();
167
168 let current_has_mcp = connections
171 .values()
172 .any(|conn| conn.project.path == current_path_str);
173
174 projects.push(ProjectInfo {
175 name: current_project_name.to_string(),
176 path: current_path_str.clone(),
177 db_path: current_db_path.display().to_string(),
178 agent: None, mcp_connected: current_has_mcp,
180 is_online: true, });
182
183 for conn in connections.values() {
185 if conn.project.path != current_path_str {
186 let mut project = conn.project.clone();
187 project.mcp_connected = true;
188 project.is_online = true; projects.push(project);
190 }
191 }
192
193 projects
194 }
195}
196
197#[derive(Debug, Serialize, Deserialize)]
203pub struct RegisterPayload {
204 pub project: ProjectInfo,
205}
206
207#[derive(Debug, Serialize, Deserialize)]
209pub struct RegisteredPayload {
210 pub success: bool,
211}
212
213#[derive(Debug, Serialize, Deserialize)]
215pub struct EmptyPayload {}
216
217#[derive(Debug, Serialize, Deserialize)]
219pub struct InitPayload {
220 pub projects: Vec<ProjectInfo>,
221}
222
223#[derive(Debug, Serialize, Deserialize)]
225pub struct ProjectOnlinePayload {
226 pub project: ProjectInfo,
227}
228
229#[derive(Debug, Serialize, Deserialize)]
231pub struct ProjectOfflinePayload {
232 pub project_path: String,
233}
234
235#[derive(Debug, Serialize, Deserialize)]
237pub struct HelloPayload {
238 pub entity_type: String,
240 #[serde(skip_serializing_if = "Option::is_none")]
242 pub capabilities: Option<Vec<String>>,
243}
244
245#[derive(Debug, Serialize, Deserialize)]
247pub struct WelcomePayload {
248 pub capabilities: Vec<String>,
250 pub session_id: String,
252}
253
254#[derive(Debug, Serialize, Deserialize)]
256pub struct GoodbyePayload {
257 #[serde(skip_serializing_if = "Option::is_none")]
259 pub reason: Option<String>,
260}
261
262#[derive(Debug, Serialize, Deserialize)]
264pub struct ErrorPayload {
265 pub code: String,
267 pub message: String,
269 #[serde(skip_serializing_if = "Option::is_none")]
271 pub details: Option<serde_json::Value>,
272}
273
/// Values for the `code` field of `ErrorPayload`.
pub mod error_codes {
    /// The peer's major protocol version differs from the server's.
    pub const UNSUPPORTED_VERSION: &str = "unsupported_version";
    /// The message could not be parsed as a protocol envelope.
    pub const INVALID_MESSAGE: &str = "invalid_message";
    /// A registered project path was rejected.
    pub const INVALID_PATH: &str = "invalid_path";
    /// Registration failed for another reason.
    pub const REGISTRATION_FAILED: &str = "registration_failed";
    /// Unexpected server-side failure.
    pub const INTERNAL_ERROR: &str = "internal_error";
}
282
283fn send_protocol_message<T: Serialize>(
289 tx: &tokio::sync::mpsc::UnboundedSender<Message>,
290 message_type: &str,
291 payload: T,
292) -> Result<(), String> {
293 let protocol_msg = ProtocolMessage::new(message_type, payload);
294 let json = protocol_msg
295 .to_json()
296 .map_err(|e| format!("Failed to serialize message: {}", e))?;
297
298 tx.send(Message::Text(json))
299 .map_err(|_| "Failed to send message: channel closed".to_string())
300}
301
302pub async fn handle_mcp_websocket(
304 ws: WebSocketUpgrade,
305 State(app_state): State<crate::dashboard::server::AppState>,
306) -> impl IntoResponse {
307 ws.on_upgrade(move |socket| handle_mcp_socket(socket, app_state.ws_state))
308}
309
310async fn handle_mcp_socket(socket: WebSocket, state: WebSocketState) {
311 let (mut sender, mut receiver) = socket.split();
312 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
313
314 let mut send_task = tokio::spawn(async move {
316 while let Some(msg) = rx.recv().await {
317 if sender.send(msg).await.is_err() {
318 break;
319 }
320 }
321 });
322
323 let mut project_path: Option<String> = None;
325 let mut session_welcomed = false; let state_for_recv = state.clone();
329
330 let heartbeat_tx = tx.clone();
332
333 let mut heartbeat_task = tokio::spawn(async move {
335 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
336 interval.tick().await;
338
339 loop {
340 interval.tick().await;
341 if send_protocol_message(&heartbeat_tx, "ping", EmptyPayload {}).is_err() {
343 break;
345 }
346 tracing::trace!("Sent heartbeat ping to MCP client");
347 }
348 });
349
350 let mut recv_task = tokio::spawn(async move {
352 while let Some(Ok(msg)) = receiver.next().await {
353 match msg {
354 Message::Text(text) => {
355 let parsed_msg = match ProtocolMessage::<serde_json::Value>::from_json(&text) {
357 Ok(msg) => msg,
358 Err(e) => {
359 tracing::warn!("Protocol error: {}", e);
360
361 let error_code = if e.contains("version mismatch") {
363 error_codes::UNSUPPORTED_VERSION
364 } else {
365 error_codes::INVALID_MESSAGE
366 };
367
368 let error_payload = ErrorPayload {
369 code: error_code.to_string(),
370 message: e.to_string(),
371 details: None,
372 };
373
374 let _ = send_protocol_message(&tx, "error", error_payload);
375 continue;
376 },
377 };
378
379 match parsed_msg.message_type.as_str() {
380 "hello" => {
381 let hello: HelloPayload =
383 match serde_json::from_value(parsed_msg.payload.clone()) {
384 Ok(h) => h,
385 Err(e) => {
386 tracing::warn!("Failed to parse hello payload: {}", e);
387 continue;
388 },
389 };
390
391 tracing::info!("Received hello from {} client", hello.entity_type);
392
393 let session_id = format!(
395 "{}-{}",
396 hello.entity_type,
397 chrono::Utc::now().timestamp_millis()
398 );
399
400 let welcome_payload = WelcomePayload {
402 session_id,
403 capabilities: vec![], };
405
406 if send_protocol_message(&tx, "welcome", welcome_payload).is_ok() {
407 session_welcomed = true;
408 tracing::debug!("Sent welcome message");
409 } else {
410 tracing::error!("Failed to send welcome message");
411 }
412 },
413 "register" => {
414 if !session_welcomed {
416 tracing::warn!(
417 "MCP client registered without hello handshake (legacy client detected)"
418 );
419 }
420
421 let project: ProjectInfo =
423 match serde_json::from_value(parsed_msg.payload.clone()) {
424 Ok(p) => p,
425 Err(e) => {
426 tracing::warn!("Failed to parse register payload: {}", e);
427 continue;
428 },
429 };
430 tracing::info!("MCP registering project: {}", project.name);
431
432 let path = project.path.clone();
433 let project_path_buf = std::path::PathBuf::from(&path);
434
435 let normalized_path = project_path_buf
438 .canonicalize()
439 .unwrap_or_else(|_| project_path_buf.clone());
440
441 let temp_dir = std::env::temp_dir()
443 .canonicalize()
444 .unwrap_or_else(|_| std::env::temp_dir());
445 let is_temp_path = normalized_path.starts_with(&temp_dir);
446
447 if is_temp_path {
448 tracing::warn!(
449 "Rejecting MCP registration for temporary/invalid path: {}",
450 path
451 );
452
453 let error_payload = ErrorPayload {
455 code: error_codes::INVALID_PATH.to_string(),
456 message: "Path is in temporary directory".to_string(),
457 details: Some(serde_json::json!({"path": path})),
458 };
459 let _ = send_protocol_message(&tx, "error", error_payload);
460
461 let _ = send_protocol_message(
463 &tx,
464 "registered",
465 RegisteredPayload { success: false },
466 );
467 continue; }
469
470 let conn = McpConnection {
472 tx: tx.clone(),
473 project: project.clone(),
474 connected_at: chrono::Utc::now(),
475 };
476
477 state_for_recv
478 .mcp_connections
479 .write()
480 .await
481 .insert(path.clone(), conn);
482 project_path = Some(path.clone());
483
484 tracing::info!("✓ MCP connected: {} ({})", project.name, path);
485
486 let _ = send_protocol_message(
488 &tx,
489 "registered",
490 RegisteredPayload { success: true },
491 );
492
493 let mut project_info = project.clone();
495 project_info.mcp_connected = true;
496 let ui_msg = ProtocolMessage::new(
497 "project_online",
498 ProjectOnlinePayload {
499 project: project_info,
500 },
501 );
502 state_for_recv
503 .broadcast_to_ui(&ui_msg.to_json().unwrap())
504 .await;
505 },
506 "pong" => {
507 tracing::trace!("Received pong from MCP client - heartbeat confirmed");
509 },
510 "goodbye" => {
511 if let Ok(goodbye_payload) =
513 serde_json::from_value::<GoodbyePayload>(parsed_msg.payload)
514 {
515 if let Some(reason) = goodbye_payload.reason {
516 tracing::info!("MCP client closing connection: {}", reason);
517 } else {
518 tracing::info!("MCP client closing connection gracefully");
519 }
520 }
521 break;
523 },
524 _ => {
525 tracing::warn!("Unknown message type: {}", parsed_msg.message_type);
526 },
527 }
528 },
529 Message::Close(_) => {
530 tracing::info!("MCP client closed WebSocket");
531 break;
532 },
533 _ => {},
534 }
535 }
536
537 project_path
538 });
539
540 tokio::select! {
542 _ = (&mut send_task) => {
543 recv_task.abort();
544 heartbeat_task.abort();
545 }
546 project_path_result = (&mut recv_task) => {
547 send_task.abort();
548 heartbeat_task.abort();
549 if let Ok(Some(path)) = project_path_result {
550 state.mcp_connections.write().await.remove(&path);
552
553 tracing::info!("MCP disconnected: {}", path);
554
555 let ui_msg = ProtocolMessage::new(
557 "project_offline",
558 ProjectOfflinePayload { project_path: path.clone() },
559 );
560 state
561 .broadcast_to_ui(&ui_msg.to_json().unwrap())
562 .await;
563
564 tracing::info!("MCP disconnected: {}", path);
565 }
566 }
567 _ = (&mut heartbeat_task) => {
568 send_task.abort();
569 recv_task.abort();
570 }
571 }
572}
573
574pub async fn handle_ui_websocket(
576 ws: WebSocketUpgrade,
577 State(app_state): State<crate::dashboard::server::AppState>,
578) -> impl IntoResponse {
579 ws.on_upgrade(move |socket| handle_ui_socket(socket, app_state))
580}
581
582async fn handle_ui_socket(socket: WebSocket, app_state: crate::dashboard::server::AppState) {
583 let (mut sender, mut receiver) = socket.split();
584 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
585
586 let mut send_task = tokio::spawn(async move {
588 while let Some(msg) = rx.recv().await {
589 if sender.send(msg).await.is_err() {
590 break;
591 }
592 }
593 });
594
595 let conn = UiConnection {
601 tx: tx.clone(),
602 connected_at: chrono::Utc::now(),
603 };
604 let conn_index = {
605 let mut connections = app_state.ws_state.ui_connections.write().await;
606 connections.push(conn);
607 connections.len() - 1
608 };
609
610 tracing::info!("UI client connected");
611
612 let app_state_for_recv = app_state.clone();
614
615 let heartbeat_tx = tx.clone();
617
618 let mut heartbeat_task = tokio::spawn(async move {
620 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
621 interval.tick().await;
623
624 loop {
625 interval.tick().await;
626 if send_protocol_message(&heartbeat_tx, "ping", EmptyPayload {}).is_err() {
627 break;
629 }
630 tracing::trace!("Sent heartbeat ping to UI client");
631 }
632 });
633
634 let mut recv_task = tokio::spawn(async move {
636 while let Some(Ok(msg)) = receiver.next().await {
637 match msg {
638 Message::Text(text) => {
639 if let Ok(parsed_msg) =
641 serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
642 {
643 match parsed_msg.message_type.as_str() {
644 "hello" => {
645 if let Ok(hello) =
647 serde_json::from_value::<HelloPayload>(parsed_msg.payload)
648 {
649 tracing::info!(
650 "Received hello from {} client",
651 hello.entity_type
652 );
653
654 let session_id = format!(
656 "{}-{}",
657 hello.entity_type,
658 chrono::Utc::now().timestamp_millis()
659 );
660
661 let welcome_payload = WelcomePayload {
663 session_id,
664 capabilities: vec![],
665 };
666
667 let _ = send_protocol_message(&tx, "welcome", welcome_payload);
668 tracing::debug!("Sent welcome message to UI");
669
670 let current_projects = {
673 let current_project =
674 app_state_for_recv.current_project.read().await;
675 let port = app_state_for_recv.port;
676 app_state_for_recv
677 .ws_state
678 .get_online_projects_with_current(
679 ¤t_project.project_name,
680 ¤t_project.project_path,
681 ¤t_project.db_path,
682 port,
683 )
684 .await
685 };
686 let _ = send_protocol_message(
687 &tx,
688 "init",
689 InitPayload {
690 projects: current_projects,
691 },
692 );
693 }
694 },
695 "pong" => {
696 tracing::trace!("Received pong from UI");
697 },
698 "goodbye" => {
699 if let Ok(goodbye_payload) =
701 serde_json::from_value::<GoodbyePayload>(parsed_msg.payload)
702 {
703 if let Some(reason) = goodbye_payload.reason {
704 tracing::info!("UI client closing: {}", reason);
705 } else {
706 tracing::info!("UI client closing gracefully");
707 }
708 }
709 break;
710 },
711 _ => {
712 tracing::trace!(
713 "Received from UI: {} ({})",
714 parsed_msg.message_type,
715 text
716 );
717 },
718 }
719 } else {
720 tracing::trace!("Received non-protocol message from UI: {}", text);
721 }
722 },
723 Message::Pong(_) => {
724 tracing::trace!("Received WebSocket pong from UI");
725 },
726 Message::Close(_) => {
727 tracing::info!("UI client closed WebSocket");
728 break;
729 },
730 _ => {},
731 }
732 }
733 });
734
735 tokio::select! {
737 _ = (&mut send_task) => {
738 recv_task.abort();
739 heartbeat_task.abort();
740 }
741 _ = (&mut recv_task) => {
742 send_task.abort();
743 heartbeat_task.abort();
744 }
745 _ = (&mut heartbeat_task) => {
746 send_task.abort();
747 recv_task.abort();
748 }
749 }
750
751 app_state
753 .ws_state
754 .ui_connections
755 .write()
756 .await
757 .swap_remove(conn_index);
758 tracing::info!("UI client disconnected");
759}