/// Version of the WebSocket protocol spoken by this server.
///
/// Stamped into every outgoing `ProtocolMessage` by `ProtocolMessage::new`.
/// Incoming messages are accepted by `ProtocolMessage::from_json` only when their
/// major version (the part before the first `.`) equals the major version here.
pub const PROTOCOL_VERSION: &str = "1.0";
6
7use axum::{
8 extract::{
9 ws::{Message, WebSocket},
10 State, WebSocketUpgrade,
11 },
12 response::IntoResponse,
13};
14use futures_util::{SinkExt, StreamExt};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19
/// Envelope wrapping every message exchanged over the dashboard WebSockets.
///
/// On the wire this is a JSON object with the keys `version`, `type`, `payload`
/// and `timestamp` (the Rust field `message_type` is renamed to `type`).
#[derive(Debug, Serialize, Deserialize)]
pub struct ProtocolMessage<T> {
    /// Protocol version of the sender (compare with [`PROTOCOL_VERSION`]).
    pub version: String,
    /// Message discriminator, e.g. `"hello"`, `"register"`, `"ping"`, `"error"`.
    #[serde(rename = "type")]
    pub message_type: String,
    /// Message-specific payload.
    pub payload: T,
    /// Creation time; `ProtocolMessage::new` fills it with an RFC 3339 UTC timestamp.
    pub timestamp: String,
}
33
34impl<T> ProtocolMessage<T>
35where
36 T: Serialize,
37{
38 pub fn new(message_type: impl Into<String>, payload: T) -> Self {
40 Self {
41 version: PROTOCOL_VERSION.to_string(),
42 message_type: message_type.into(),
43 payload,
44 timestamp: chrono::Utc::now().to_rfc3339(),
45 }
46 }
47
48 pub fn to_json(&self) -> Result<String, serde_json::Error> {
50 serde_json::to_string(self)
51 }
52}
53
54impl<T> ProtocolMessage<T>
55where
56 T: for<'de> Deserialize<'de>,
57{
58 pub fn from_json(json: &str) -> Result<Self, String> {
60 let msg: Self = serde_json::from_str(json)
61 .map_err(|e| format!("Failed to parse protocol message: {}", e))?;
62
63 let expected_major = PROTOCOL_VERSION.split('.').next().unwrap_or("1");
65 let received_major = msg.version.split('.').next().unwrap_or("0");
66
67 if expected_major != received_major {
68 return Err(format!(
69 "Protocol version mismatch: expected {}, got {}",
70 PROTOCOL_VERSION, msg.version
71 ));
72 }
73
74 Ok(msg)
75 }
76}
77
/// A project as exchanged with MCP clients (in `register`) and reported to UI
/// clients (in `init` and `project_online`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectInfo {
    /// Project root path; also the key under which an MCP registration is stored
    /// in `WebSocketState::mcp_connections`.
    pub path: String,
    /// Display name of the project.
    pub name: String,
    /// Path of the project's database, as a display string.
    pub db_path: String,
    /// Optional agent name; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent: Option<String>,
    /// Whether an MCP client currently has this project registered.
    pub mcp_connected: bool,
    /// Whether the dashboard should show this project as online.
    pub is_online: bool,
}
91
/// An MCP client WebSocket whose `register` request was accepted.
#[derive(Debug)]
pub struct McpConnection {
    /// Outgoing queue of the socket; the connection's send task writes every
    /// queued message to the WebSocket.
    pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
    /// Project announced by the client in its `register` message.
    pub project: ProjectInfo,
    /// Time at which the registration was accepted.
    pub connected_at: chrono::DateTime<chrono::Utc>,
}
99
/// A connected dashboard UI WebSocket that receives broadcasts.
#[derive(Debug)]
pub struct UiConnection {
    /// Outgoing queue of the socket; the connection's send task writes every
    /// queued message to the WebSocket.
    pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
    /// Time at which the UI client connected.
    pub connected_at: chrono::DateTime<chrono::Utc>,
}
106
/// Shared registry of live MCP and UI WebSocket connections.
///
/// Clones share the same underlying collections through `Arc`.
#[derive(Clone)]
pub struct WebSocketState {
    /// Accepted MCP registrations, keyed by the registered project path.
    pub mcp_connections: Arc<RwLock<HashMap<String, McpConnection>>>,
    /// Connected UI clients; every entry receives `broadcast_to_ui` messages.
    pub ui_connections: Arc<RwLock<Vec<UiConnection>>>,
}
115
116impl Default for WebSocketState {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
122impl WebSocketState {
123 pub fn new() -> Self {
124 Self {
125 mcp_connections: Arc::new(RwLock::new(HashMap::new())),
126 ui_connections: Arc::new(RwLock::new(Vec::new())),
127 }
128 }
129
130 pub async fn broadcast_to_ui(&self, message: &str) {
132 let connections = self.ui_connections.read().await;
133 for conn in connections.iter() {
134 let _ = conn.tx.send(Message::Text(message.to_string()));
135 }
136 }
137
138 pub async fn get_online_projects(&self) -> Vec<ProjectInfo> {
140 let connections = self.mcp_connections.read().await;
142
143 connections
144 .values()
145 .map(|conn| {
146 let mut project = conn.project.clone();
147 project.mcp_connected = true; project
149 })
150 .collect()
151 }
152
153 pub async fn get_online_projects_with_current(
157 &self,
158 current_project_name: &str,
159 current_project_path: &std::path::Path,
160 current_db_path: &std::path::Path,
161 host_project: &ProjectInfo,
162 _port: u16,
163 ) -> Vec<ProjectInfo> {
164 let connections = self.mcp_connections.read().await;
165 let current_path_str = current_project_path.display().to_string();
166
167 let mut projects = Vec::new();
168
169 let current_has_mcp = connections
172 .values()
173 .any(|conn| conn.project.path == current_path_str);
174
175 let is_host_current = current_path_str == host_project.path;
177
178 let is_current_online = is_host_current || current_has_mcp;
181
182 projects.push(ProjectInfo {
183 name: current_project_name.to_string(),
184 path: current_path_str.clone(),
185 db_path: current_db_path.display().to_string(),
186 agent: None, mcp_connected: current_has_mcp,
188 is_online: is_current_online,
189 });
190
191 if !is_host_current {
194 let host_has_mcp = connections
195 .values()
196 .any(|conn| conn.project.path == host_project.path);
197
198 let mut host = host_project.clone();
199 host.mcp_connected = host_has_mcp;
200 host.is_online = true; projects.push(host);
202 }
203
204 for conn in connections.values() {
206 if conn.project.path != current_path_str && conn.project.path != host_project.path {
207 let mut project = conn.project.clone();
208 project.mcp_connected = true;
209 project.is_online = true; projects.push(project);
211 }
212 }
213
214 projects
215 }
216}
217
/// Payload wrapping the project an MCP client announces in `register`.
///
/// NOTE(review): `handle_mcp_socket` currently deserializes the `register`
/// payload directly as `ProjectInfo`, not as this wrapper. Confirm which shape
/// clients actually send.
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterPayload {
    /// The project being registered.
    pub project: ProjectInfo,
}
227
/// Server reply to an MCP `register` request.
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisteredPayload {
    /// `false` when the registration was rejected (e.g. the path lies in the
    /// temporary directory); `true` once it has been stored.
    pub success: bool,
}
233
/// Payload for messages that carry no data, such as the server's heartbeat
/// `ping`; serializes as `{}`.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmptyPayload {}
237
/// Sent to a UI client right after its `hello`/`welcome` handshake.
#[derive(Debug, Serialize, Deserialize)]
pub struct InitPayload {
    /// Projects currently known to the server.
    pub projects: Vec<ProjectInfo>,
}
243
/// Broadcast to UI clients when an MCP client registers a project.
#[derive(Debug, Serialize, Deserialize)]
pub struct ProjectOnlinePayload {
    /// The registered project, with `mcp_connected` set to `true`.
    pub project: ProjectInfo,
}
249
/// Broadcast to UI clients when a registered MCP client disconnects.
#[derive(Debug, Serialize, Deserialize)]
pub struct ProjectOfflinePayload {
    /// Path under which the project had been registered.
    pub project_path: String,
}
255
/// First message of the handshake, sent by a client.
#[derive(Debug, Serialize, Deserialize)]
pub struct HelloPayload {
    /// Kind of client. The handlers log it and use it as the prefix of the
    /// generated session id.
    pub entity_type: String,
    /// Optional capability list; omitted from the JSON when `None`. Not
    /// inspected by the handlers in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub capabilities: Option<Vec<String>>,
}
265
/// Server reply to a client's `hello`.
#[derive(Debug, Serialize, Deserialize)]
pub struct WelcomePayload {
    /// Server capabilities; the handlers in this file always send an empty list.
    pub capabilities: Vec<String>,
    /// Session identifier of the form `<entity_type>-<unix time in milliseconds>`.
    pub session_id: String,
}
274
/// Sent by a client just before it closes the connection.
#[derive(Debug, Serialize, Deserialize)]
pub struct GoodbyePayload {
    /// Optional reason, logged by the server; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
}
282
/// Payload of an `error` message sent by the server.
#[derive(Debug, Serialize, Deserialize)]
pub struct ErrorPayload {
    /// Machine-readable code; the handlers use the constants in `error_codes`.
    pub code: String,
    /// Human-readable description.
    pub message: String,
    /// Optional structured context, e.g. `{"path": ...}` for `invalid_path`;
    /// omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<serde_json::Value>,
}
294
/// Values used for `ErrorPayload::code`.
pub mod error_codes {
    /// The peer's protocol major version differs from the server's.
    pub const UNSUPPORTED_VERSION: &str = "unsupported_version";
    /// The message could not be parsed as a protocol envelope.
    pub const INVALID_MESSAGE: &str = "invalid_message";
    /// A registration path was rejected (e.g. it lies in the temporary directory).
    pub const INVALID_PATH: &str = "invalid_path";
    /// Registration failed. NOTE(review): not referenced in this file.
    pub const REGISTRATION_FAILED: &str = "registration_failed";
    /// Unexpected server-side failure. NOTE(review): not referenced in this file.
    pub const INTERNAL_ERROR: &str = "internal_error";
}
303
/// Describes a change to a project's database.
///
/// NOTE(review): presumably the payload of `db_operation` messages. The MCP
/// handler forwards those verbatim to UI clients without parsing them into
/// this type.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DatabaseOperationPayload {
    /// Operation kind; the constructors use `"create"`, `"update"`, `"delete"`
    /// and `"read"`.
    pub operation: String,

    /// Entity kind; the constructors use `"task"` and `"event"`.
    pub entity: String,

    /// IDs of the affected entities.
    pub affected_ids: Vec<i64>,

    /// Optional entity data; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,

    /// Path of the project whose database changed.
    pub project_path: String,
}
324
325impl DatabaseOperationPayload {
326 pub fn new(
328 operation: impl Into<String>,
329 entity: impl Into<String>,
330 affected_ids: Vec<i64>,
331 data: Option<serde_json::Value>,
332 project_path: impl Into<String>,
333 ) -> Self {
334 Self {
335 operation: operation.into(),
336 entity: entity.into(),
337 affected_ids,
338 data,
339 project_path: project_path.into(),
340 }
341 }
342
343 pub fn task_created(
345 task_id: i64,
346 task_data: serde_json::Value,
347 project_path: impl Into<String>,
348 ) -> Self {
349 Self::new(
350 "create",
351 "task",
352 vec![task_id],
353 Some(task_data),
354 project_path,
355 )
356 }
357
358 pub fn task_updated(
360 task_id: i64,
361 task_data: serde_json::Value,
362 project_path: impl Into<String>,
363 ) -> Self {
364 Self::new(
365 "update",
366 "task",
367 vec![task_id],
368 Some(task_data),
369 project_path,
370 )
371 }
372
373 pub fn task_deleted(task_id: i64, project_path: impl Into<String>) -> Self {
375 Self::new("delete", "task", vec![task_id], None, project_path)
376 }
377
378 pub fn task_read(task_id: i64, project_path: impl Into<String>) -> Self {
380 Self::new("read", "task", vec![task_id], None, project_path)
381 }
382
383 pub fn event_created(
385 event_id: i64,
386 event_data: serde_json::Value,
387 project_path: impl Into<String>,
388 ) -> Self {
389 Self::new(
390 "create",
391 "event",
392 vec![event_id],
393 Some(event_data),
394 project_path,
395 )
396 }
397
398 pub fn event_updated(
400 event_id: i64,
401 event_data: serde_json::Value,
402 project_path: impl Into<String>,
403 ) -> Self {
404 Self::new(
405 "update",
406 "event",
407 vec![event_id],
408 Some(event_data),
409 project_path,
410 )
411 }
412
413 pub fn event_deleted(event_id: i64, project_path: impl Into<String>) -> Self {
415 Self::new("delete", "event", vec![event_id], None, project_path)
416 }
417}
418
419fn send_protocol_message<T: Serialize>(
425 tx: &tokio::sync::mpsc::UnboundedSender<Message>,
426 message_type: &str,
427 payload: T,
428) -> Result<(), String> {
429 let protocol_msg = ProtocolMessage::new(message_type, payload);
430 let json = protocol_msg
431 .to_json()
432 .map_err(|e| format!("Failed to serialize message: {}", e))?;
433
434 tx.send(Message::Text(json))
435 .map_err(|_| "Failed to send message: channel closed".to_string())
436}
437
438pub async fn handle_mcp_websocket(
440 ws: WebSocketUpgrade,
441 State(app_state): State<crate::dashboard::server::AppState>,
442) -> impl IntoResponse {
443 ws.on_upgrade(move |socket| handle_mcp_socket(socket, app_state.ws_state))
444}
445
446async fn handle_mcp_socket(socket: WebSocket, state: WebSocketState) {
447 let (mut sender, mut receiver) = socket.split();
448 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
449
450 let mut send_task = tokio::spawn(async move {
452 while let Some(msg) = rx.recv().await {
453 if sender.send(msg).await.is_err() {
454 break;
455 }
456 }
457 });
458
459 let mut project_path: Option<String> = None;
461 let mut session_welcomed = false; let state_for_recv = state.clone();
465
466 let heartbeat_tx = tx.clone();
468
469 let mut heartbeat_task = tokio::spawn(async move {
471 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
472 interval.tick().await;
474
475 loop {
476 interval.tick().await;
477 if send_protocol_message(&heartbeat_tx, "ping", EmptyPayload {}).is_err() {
479 break;
481 }
482 tracing::trace!("Sent heartbeat ping to MCP client");
483 }
484 });
485
486 let mut recv_task = tokio::spawn(async move {
488 while let Some(Ok(msg)) = receiver.next().await {
489 match msg {
490 Message::Text(text) => {
491 let parsed_msg = match ProtocolMessage::<serde_json::Value>::from_json(&text) {
493 Ok(msg) => msg,
494 Err(e) => {
495 tracing::warn!("Protocol error: {}", e);
496
497 let error_code = if e.contains("version mismatch") {
499 error_codes::UNSUPPORTED_VERSION
500 } else {
501 error_codes::INVALID_MESSAGE
502 };
503
504 let error_payload = ErrorPayload {
505 code: error_code.to_string(),
506 message: e.to_string(),
507 details: None,
508 };
509
510 let _ = send_protocol_message(&tx, "error", error_payload);
511 continue;
512 },
513 };
514
515 match parsed_msg.message_type.as_str() {
516 "hello" => {
517 let hello: HelloPayload =
519 match serde_json::from_value(parsed_msg.payload.clone()) {
520 Ok(h) => h,
521 Err(e) => {
522 tracing::warn!("Failed to parse hello payload: {}", e);
523 continue;
524 },
525 };
526
527 tracing::info!("Received hello from {} client", hello.entity_type);
528
529 let session_id = format!(
531 "{}-{}",
532 hello.entity_type,
533 chrono::Utc::now().timestamp_millis()
534 );
535
536 let welcome_payload = WelcomePayload {
538 session_id,
539 capabilities: vec![], };
541
542 if send_protocol_message(&tx, "welcome", welcome_payload).is_ok() {
543 session_welcomed = true;
544 tracing::debug!("Sent welcome message");
545 } else {
546 tracing::error!("Failed to send welcome message");
547 }
548 },
549 "register" => {
550 if !session_welcomed {
552 tracing::warn!(
553 "MCP client registered without hello handshake (legacy client detected)"
554 );
555 }
556
557 let project: ProjectInfo =
559 match serde_json::from_value(parsed_msg.payload.clone()) {
560 Ok(p) => p,
561 Err(e) => {
562 tracing::warn!("Failed to parse register payload: {}", e);
563 continue;
564 },
565 };
566 tracing::info!("MCP registering project: {}", project.name);
567
568 let path = project.path.clone();
569 let project_path_buf = std::path::PathBuf::from(&path);
570
571 let normalized_path = project_path_buf
574 .canonicalize()
575 .unwrap_or_else(|_| project_path_buf.clone());
576
577 let temp_dir = std::env::temp_dir()
579 .canonicalize()
580 .unwrap_or_else(|_| std::env::temp_dir());
581 let is_temp_path = normalized_path.starts_with(&temp_dir);
582
583 if is_temp_path {
584 tracing::warn!(
585 "Rejecting MCP registration for temporary/invalid path: {}",
586 path
587 );
588
589 let error_payload = ErrorPayload {
591 code: error_codes::INVALID_PATH.to_string(),
592 message: "Path is in temporary directory".to_string(),
593 details: Some(serde_json::json!({"path": path})),
594 };
595 let _ = send_protocol_message(&tx, "error", error_payload);
596
597 let _ = send_protocol_message(
599 &tx,
600 "registered",
601 RegisteredPayload { success: false },
602 );
603 continue; }
605
606 let conn = McpConnection {
608 tx: tx.clone(),
609 project: project.clone(),
610 connected_at: chrono::Utc::now(),
611 };
612
613 state_for_recv
614 .mcp_connections
615 .write()
616 .await
617 .insert(path.clone(), conn);
618 project_path = Some(path.clone());
619
620 tracing::info!("✓ MCP connected: {} ({})", project.name, path);
621
622 let _ = send_protocol_message(
624 &tx,
625 "registered",
626 RegisteredPayload { success: true },
627 );
628
629 let mut project_info = project.clone();
631 project_info.mcp_connected = true;
632 let ui_msg = ProtocolMessage::new(
633 "project_online",
634 ProjectOnlinePayload {
635 project: project_info,
636 },
637 );
638 if let Ok(json) = ui_msg.to_json() {
639 state_for_recv.broadcast_to_ui(&json).await;
640 } else {
641 tracing::error!("Failed to serialize project_online message");
642 }
643 },
644 "pong" => {
645 tracing::trace!("Received pong from MCP client - heartbeat confirmed");
647 },
648 "goodbye" => {
649 if let Ok(goodbye_payload) =
651 serde_json::from_value::<GoodbyePayload>(parsed_msg.payload)
652 {
653 if let Some(reason) = goodbye_payload.reason {
654 tracing::info!("MCP client closing connection: {}", reason);
655 } else {
656 tracing::info!("MCP client closing connection gracefully");
657 }
658 }
659 break;
661 },
662 "db_operation" => {
663 tracing::debug!(
666 "Received db_operation from MCP, forwarding to UI clients"
667 );
668 state_for_recv.broadcast_to_ui(&text).await;
669 },
670 _ => {
671 tracing::warn!("Unknown message type: {}", parsed_msg.message_type);
672 },
673 }
674 },
675 Message::Close(_) => {
676 tracing::info!("MCP client closed WebSocket");
677 break;
678 },
679 _ => {},
680 }
681 }
682
683 project_path
684 });
685
686 tokio::select! {
688 _ = (&mut send_task) => {
689 recv_task.abort();
690 heartbeat_task.abort();
691 }
692 project_path_result = (&mut recv_task) => {
693 send_task.abort();
694 heartbeat_task.abort();
695 if let Ok(Some(path)) = project_path_result {
696 state.mcp_connections.write().await.remove(&path);
698
699 tracing::info!("MCP disconnected: {}", path);
700
701 let ui_msg = ProtocolMessage::new(
703 "project_offline",
704 ProjectOfflinePayload { project_path: path.clone() },
705 );
706 if let Ok(json) = ui_msg.to_json() {
707 state.broadcast_to_ui(&json).await;
708 } else {
709 tracing::error!("Failed to serialize project_offline message");
710 }
711
712 tracing::info!("MCP disconnected: {}", path);
713 }
714 }
715 _ = (&mut heartbeat_task) => {
716 send_task.abort();
717 recv_task.abort();
718 }
719 }
720}
721
722pub async fn handle_ui_websocket(
724 ws: WebSocketUpgrade,
725 State(app_state): State<crate::dashboard::server::AppState>,
726) -> impl IntoResponse {
727 ws.on_upgrade(move |socket| handle_ui_socket(socket, app_state))
728}
729
730async fn handle_ui_socket(socket: WebSocket, app_state: crate::dashboard::server::AppState) {
731 let (mut sender, mut receiver) = socket.split();
732 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
733
734 let mut send_task = tokio::spawn(async move {
736 while let Some(msg) = rx.recv().await {
737 if sender.send(msg).await.is_err() {
738 break;
739 }
740 }
741 });
742
743 let conn = UiConnection {
749 tx: tx.clone(),
750 connected_at: chrono::Utc::now(),
751 };
752 let conn_index = {
753 let mut connections = app_state.ws_state.ui_connections.write().await;
754 connections.push(conn);
755 connections.len() - 1
756 };
757
758 tracing::info!("UI client connected");
759
760 let app_state_for_recv = app_state.clone();
762
763 let heartbeat_tx = tx.clone();
765
766 let mut heartbeat_task = tokio::spawn(async move {
768 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
769 interval.tick().await;
771
772 loop {
773 interval.tick().await;
774 if send_protocol_message(&heartbeat_tx, "ping", EmptyPayload {}).is_err() {
775 break;
777 }
778 tracing::trace!("Sent heartbeat ping to UI client");
779 }
780 });
781
782 let mut recv_task = tokio::spawn(async move {
784 while let Some(Ok(msg)) = receiver.next().await {
785 match msg {
786 Message::Text(text) => {
787 if let Ok(parsed_msg) =
789 serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
790 {
791 match parsed_msg.message_type.as_str() {
792 "hello" => {
793 if let Ok(hello) =
795 serde_json::from_value::<HelloPayload>(parsed_msg.payload)
796 {
797 tracing::info!(
798 "Received hello from {} client",
799 hello.entity_type
800 );
801
802 let session_id = format!(
804 "{}-{}",
805 hello.entity_type,
806 chrono::Utc::now().timestamp_millis()
807 );
808
809 let welcome_payload = WelcomePayload {
811 session_id,
812 capabilities: vec![],
813 };
814
815 let _ = send_protocol_message(&tx, "welcome", welcome_payload);
816 tracing::debug!("Sent welcome message to UI");
817
818 let current_projects = {
821 let port = app_state_for_recv.port;
822 match app_state_for_recv.get_active_project().await {
823 Some(active) => {
824 app_state_for_recv
825 .ws_state
826 .get_online_projects_with_current(
827 &active.name,
828 &active.path,
829 &active.db_path,
830 &app_state_for_recv.host_project,
831 port,
832 )
833 .await
834 },
835 None => {
836 vec![app_state_for_recv.host_project.clone()]
837 },
838 }
839 };
840 let _ = send_protocol_message(
841 &tx,
842 "init",
843 InitPayload {
844 projects: current_projects,
845 },
846 );
847 }
848 },
849 "pong" => {
850 tracing::trace!("Received pong from UI");
851 },
852 "goodbye" => {
853 if let Ok(goodbye_payload) =
855 serde_json::from_value::<GoodbyePayload>(parsed_msg.payload)
856 {
857 if let Some(reason) = goodbye_payload.reason {
858 tracing::info!("UI client closing: {}", reason);
859 } else {
860 tracing::info!("UI client closing gracefully");
861 }
862 }
863 break;
864 },
865 _ => {
866 tracing::trace!(
867 "Received from UI: {} ({})",
868 parsed_msg.message_type,
869 text
870 );
871 },
872 }
873 } else {
874 tracing::trace!("Received non-protocol message from UI: {}", text);
875 }
876 },
877 Message::Pong(_) => {
878 tracing::trace!("Received WebSocket pong from UI");
879 },
880 Message::Close(_) => {
881 tracing::info!("UI client closed WebSocket");
882 break;
883 },
884 _ => {},
885 }
886 }
887 });
888
889 tokio::select! {
891 _ = (&mut send_task) => {
892 recv_task.abort();
893 heartbeat_task.abort();
894 }
895 _ = (&mut recv_task) => {
896 send_task.abort();
897 heartbeat_task.abort();
898 }
899 _ = (&mut heartbeat_task) => {
900 send_task.abort();
901 recv_task.abort();
902 }
903 }
904
905 app_state
907 .ws_state
908 .ui_connections
909 .write()
910 .await
911 .swap_remove(conn_index);
912 tracing::info!("UI client disconnected");
913}