/// Protocol version stamped on every message built by `ProtocolMessage::new`.
///
/// When parsing, only the major component (the text before the first `.`)
/// is compared against the sender's version.
pub const PROTOCOL_VERSION: &str = "1.0";
6
7use axum::{
8 extract::{
9 ws::{Message, WebSocket},
10 State, WebSocketUpgrade,
11 },
12 response::IntoResponse,
13};
14use futures_util::{SinkExt, StreamExt};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19
/// Envelope shared by every WebSocket message in this module: a version,
/// a type tag, a typed payload and a timestamp.
#[derive(Debug, Serialize, Deserialize)]
pub struct ProtocolMessage<T> {
    /// Protocol version of the sender; `from_json` compares its major part
    /// with [`PROTOCOL_VERSION`].
    pub version: String,
    /// Message discriminator, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub message_type: String,
    /// Message-specific body.
    pub payload: T,
    /// Creation time; `new` fills it with the current UTC time formatted as
    /// RFC 3339.
    pub timestamp: String,
}
33
34impl<T> ProtocolMessage<T>
35where
36 T: Serialize,
37{
38 pub fn new(message_type: impl Into<String>, payload: T) -> Self {
40 Self {
41 version: PROTOCOL_VERSION.to_string(),
42 message_type: message_type.into(),
43 payload,
44 timestamp: chrono::Utc::now().to_rfc3339(),
45 }
46 }
47
48 pub fn to_json(&self) -> Result<String, serde_json::Error> {
50 serde_json::to_string(self)
51 }
52}
53
54impl<T> ProtocolMessage<T>
55where
56 T: for<'de> Deserialize<'de>,
57{
58 pub fn from_json(json: &str) -> Result<Self, String> {
60 let msg: Self = serde_json::from_str(json)
61 .map_err(|e| format!("Failed to parse protocol message: {}", e))?;
62
63 let expected_major = PROTOCOL_VERSION.split('.').next().unwrap_or("1");
65 let received_major = msg.version.split('.').next().unwrap_or("0");
66
67 if expected_major != received_major {
68 return Err(format!(
69 "Protocol version mismatch: expected {}, got {}",
70 PROTOCOL_VERSION, msg.version
71 ));
72 }
73
74 Ok(msg)
75 }
76}
77
/// Description of a project as exchanged between MCP clients, this server
/// and UI clients.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectInfo {
    /// Project path as a string; the MCP handler uses it as the key of
    /// `WebSocketState::mcp_connections`.
    pub path: String,
    /// Project name (used in log messages).
    pub name: String,
    /// Path of the project's database, as a string.
    pub db_path: String,
    /// Optional agent value; left out of the JSON when `None`.
    /// NOTE(review): its exact meaning is not visible in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent: Option<String>,
    /// Whether an MCP client is connected for this project; the server
    /// overwrites it with `true` for registered connections.
    pub mcp_connected: bool,
    /// Online flag; `get_online_projects_with_current` always reports `true`.
    pub is_online: bool,
}
91
/// Registry entry for one registered MCP client.
#[derive(Debug)]
pub struct McpConnection {
    /// Sending half of the channel drained by that client's socket writer task.
    pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
    /// Project information supplied by the client in its `register` message.
    pub project: ProjectInfo,
    /// UTC time at which the registration was stored.
    pub connected_at: chrono::DateTime<chrono::Utc>,
}
99
/// Registry entry for one connected UI client.
#[derive(Debug)]
pub struct UiConnection {
    /// Sending half of the channel drained by that client's socket writer
    /// task; `broadcast_to_ui` pushes messages through it.
    pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
    /// UTC time at which the connection was added to the registry.
    pub connected_at: chrono::DateTime<chrono::Utc>,
}
106
/// Shared registry of live WebSocket connections.
///
/// Cloning is cheap: both fields are `Arc`s, so every clone refers to the
/// same underlying maps.
#[derive(Clone)]
pub struct WebSocketState {
    /// Registered MCP clients, keyed by the project path they registered with.
    pub mcp_connections: Arc<RwLock<HashMap<String, McpConnection>>>,
    /// Currently connected UI clients.
    pub ui_connections: Arc<RwLock<Vec<UiConnection>>>,
}
115
116impl Default for WebSocketState {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
122impl WebSocketState {
123 pub fn new() -> Self {
124 Self {
125 mcp_connections: Arc::new(RwLock::new(HashMap::new())),
126 ui_connections: Arc::new(RwLock::new(Vec::new())),
127 }
128 }
129
130 pub async fn broadcast_to_ui(&self, message: &str) {
132 let connections = self.ui_connections.read().await;
133 for conn in connections.iter() {
134 let _ = conn.tx.send(Message::Text(message.to_string()));
135 }
136 }
137
138 pub async fn get_online_projects(&self) -> Vec<ProjectInfo> {
140 let connections = self.mcp_connections.read().await;
142
143 connections
144 .values()
145 .map(|conn| {
146 let mut project = conn.project.clone();
147 project.mcp_connected = true; project
149 })
150 .collect()
151 }
152
153 pub async fn get_online_projects_with_current(
157 &self,
158 current_project_name: &str,
159 current_project_path: &std::path::Path,
160 current_db_path: &std::path::Path,
161 _port: u16,
162 ) -> Vec<ProjectInfo> {
163 let connections = self.mcp_connections.read().await;
164 let current_path_str = current_project_path.display().to_string();
165
166 let mut projects = Vec::new();
167
168 let current_has_mcp = connections
171 .values()
172 .any(|conn| conn.project.path == current_path_str);
173
174 projects.push(ProjectInfo {
175 name: current_project_name.to_string(),
176 path: current_path_str.clone(),
177 db_path: current_db_path.display().to_string(),
178 agent: None, mcp_connected: current_has_mcp,
180 is_online: true, });
182
183 for conn in connections.values() {
185 if conn.project.path != current_path_str {
186 let mut project = conn.project.clone();
187 project.mcp_connected = true;
188 project.is_online = true; projects.push(project);
190 }
191 }
192
193 projects
194 }
195}
196
/// Payload wrapping the project an MCP client wants to register.
///
/// NOTE(review): the `register` handler in this file deserializes the payload
/// directly as `ProjectInfo`, not as this wrapper — confirm which shape
/// clients actually send.
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterPayload {
    /// Project being registered.
    pub project: ProjectInfo,
}
206
/// Payload of the `registered` reply sent to an MCP client after a
/// `register` request.
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisteredPayload {
    /// `true` when the registration was stored, `false` when it was rejected.
    pub success: bool,
}
212
/// Payload with no fields; used for the heartbeat `ping` messages.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmptyPayload {}
216
/// Payload of the `init` message sent to a UI client after its `hello`.
#[derive(Debug, Serialize, Deserialize)]
pub struct InitPayload {
    /// Projects known at that moment, current project first.
    pub projects: Vec<ProjectInfo>,
}
222
/// Payload of the `project_online` broadcast sent to UI clients when an MCP
/// client registers.
#[derive(Debug, Serialize, Deserialize)]
pub struct ProjectOnlinePayload {
    /// The project that just registered.
    pub project: ProjectInfo,
}
228
/// Payload of the `project_offline` broadcast sent to UI clients when a
/// registered MCP client disconnects.
#[derive(Debug, Serialize, Deserialize)]
pub struct ProjectOfflinePayload {
    /// Path under which the disconnected project had been registered.
    pub project_path: String,
}
234
/// Payload of the `hello` message a client sends to open a session.
#[derive(Debug, Serialize, Deserialize)]
pub struct HelloPayload {
    /// Kind of client; it is logged and used as the prefix of the generated
    /// session id.
    pub entity_type: String,
    /// Optional capability list; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub capabilities: Option<Vec<String>>,
}
244
/// Payload of the `welcome` reply to a client's `hello`.
#[derive(Debug, Serialize, Deserialize)]
pub struct WelcomePayload {
    /// Capability list; the handlers in this file always send it empty.
    pub capabilities: Vec<String>,
    /// Session id built as `<entity_type>-<current time in milliseconds>`.
    pub session_id: String,
}
253
/// Payload of the `goodbye` message a client sends before closing.
#[derive(Debug, Serialize, Deserialize)]
pub struct GoodbyePayload {
    /// Optional reason, logged when present; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
}
261
/// Payload of an `error` message sent to a client.
#[derive(Debug, Serialize, Deserialize)]
pub struct ErrorPayload {
    /// Error code; the handlers in this file use constants from `error_codes`.
    pub code: String,
    /// Human-readable description of the problem.
    pub message: String,
    /// Optional structured details; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<serde_json::Value>,
}
273
/// String codes placed in the `code` field of `error` messages.
pub mod error_codes {
    /// Unexpected server-side failure.
    pub const INTERNAL_ERROR: &str = "internal_error";
    /// The incoming message could not be parsed as a protocol message.
    pub const INVALID_MESSAGE: &str = "invalid_message";
    /// The supplied project path was rejected.
    pub const INVALID_PATH: &str = "invalid_path";
    /// Registration could not be completed.
    pub const REGISTRATION_FAILED: &str = "registration_failed";
    /// The sender's major protocol version differs from the server's.
    pub const UNSUPPORTED_VERSION: &str = "unsupported_version";
}
282
/// Payload describing a database operation performed on a project.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DatabaseOperationPayload {
    /// Operation name; the constructors in this file use `"create"`,
    /// `"update"`, `"delete"` and `"read"`.
    pub operation: String,

    /// Entity kind; the constructors in this file use `"task"` and `"event"`.
    pub entity: String,

    /// Ids of the rows the operation touched.
    pub affected_ids: Vec<i64>,

    /// Optional entity data; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,

    /// Path of the project the operation belongs to.
    pub project_path: String,
}
303
304impl DatabaseOperationPayload {
305 pub fn new(
307 operation: impl Into<String>,
308 entity: impl Into<String>,
309 affected_ids: Vec<i64>,
310 data: Option<serde_json::Value>,
311 project_path: impl Into<String>,
312 ) -> Self {
313 Self {
314 operation: operation.into(),
315 entity: entity.into(),
316 affected_ids,
317 data,
318 project_path: project_path.into(),
319 }
320 }
321
322 pub fn task_created(
324 task_id: i64,
325 task_data: serde_json::Value,
326 project_path: impl Into<String>,
327 ) -> Self {
328 Self::new(
329 "create",
330 "task",
331 vec![task_id],
332 Some(task_data),
333 project_path,
334 )
335 }
336
337 pub fn task_updated(
339 task_id: i64,
340 task_data: serde_json::Value,
341 project_path: impl Into<String>,
342 ) -> Self {
343 Self::new(
344 "update",
345 "task",
346 vec![task_id],
347 Some(task_data),
348 project_path,
349 )
350 }
351
352 pub fn task_deleted(task_id: i64, project_path: impl Into<String>) -> Self {
354 Self::new("delete", "task", vec![task_id], None, project_path)
355 }
356
357 pub fn task_read(task_id: i64, project_path: impl Into<String>) -> Self {
359 Self::new("read", "task", vec![task_id], None, project_path)
360 }
361
362 pub fn event_created(
364 event_id: i64,
365 event_data: serde_json::Value,
366 project_path: impl Into<String>,
367 ) -> Self {
368 Self::new(
369 "create",
370 "event",
371 vec![event_id],
372 Some(event_data),
373 project_path,
374 )
375 }
376
377 pub fn event_updated(
379 event_id: i64,
380 event_data: serde_json::Value,
381 project_path: impl Into<String>,
382 ) -> Self {
383 Self::new(
384 "update",
385 "event",
386 vec![event_id],
387 Some(event_data),
388 project_path,
389 )
390 }
391
392 pub fn event_deleted(event_id: i64, project_path: impl Into<String>) -> Self {
394 Self::new("delete", "event", vec![event_id], None, project_path)
395 }
396}
397
398fn send_protocol_message<T: Serialize>(
404 tx: &tokio::sync::mpsc::UnboundedSender<Message>,
405 message_type: &str,
406 payload: T,
407) -> Result<(), String> {
408 let protocol_msg = ProtocolMessage::new(message_type, payload);
409 let json = protocol_msg
410 .to_json()
411 .map_err(|e| format!("Failed to serialize message: {}", e))?;
412
413 tx.send(Message::Text(json))
414 .map_err(|_| "Failed to send message: channel closed".to_string())
415}
416
417pub async fn handle_mcp_websocket(
419 ws: WebSocketUpgrade,
420 State(app_state): State<crate::dashboard::server::AppState>,
421) -> impl IntoResponse {
422 ws.on_upgrade(move |socket| handle_mcp_socket(socket, app_state.ws_state))
423}
424
425async fn handle_mcp_socket(socket: WebSocket, state: WebSocketState) {
426 let (mut sender, mut receiver) = socket.split();
427 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
428
429 let mut send_task = tokio::spawn(async move {
431 while let Some(msg) = rx.recv().await {
432 if sender.send(msg).await.is_err() {
433 break;
434 }
435 }
436 });
437
438 let mut project_path: Option<String> = None;
440 let mut session_welcomed = false; let state_for_recv = state.clone();
444
445 let heartbeat_tx = tx.clone();
447
448 let mut heartbeat_task = tokio::spawn(async move {
450 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
451 interval.tick().await;
453
454 loop {
455 interval.tick().await;
456 if send_protocol_message(&heartbeat_tx, "ping", EmptyPayload {}).is_err() {
458 break;
460 }
461 tracing::trace!("Sent heartbeat ping to MCP client");
462 }
463 });
464
465 let mut recv_task = tokio::spawn(async move {
467 while let Some(Ok(msg)) = receiver.next().await {
468 match msg {
469 Message::Text(text) => {
470 let parsed_msg = match ProtocolMessage::<serde_json::Value>::from_json(&text) {
472 Ok(msg) => msg,
473 Err(e) => {
474 tracing::warn!("Protocol error: {}", e);
475
476 let error_code = if e.contains("version mismatch") {
478 error_codes::UNSUPPORTED_VERSION
479 } else {
480 error_codes::INVALID_MESSAGE
481 };
482
483 let error_payload = ErrorPayload {
484 code: error_code.to_string(),
485 message: e.to_string(),
486 details: None,
487 };
488
489 let _ = send_protocol_message(&tx, "error", error_payload);
490 continue;
491 },
492 };
493
494 match parsed_msg.message_type.as_str() {
495 "hello" => {
496 let hello: HelloPayload =
498 match serde_json::from_value(parsed_msg.payload.clone()) {
499 Ok(h) => h,
500 Err(e) => {
501 tracing::warn!("Failed to parse hello payload: {}", e);
502 continue;
503 },
504 };
505
506 tracing::info!("Received hello from {} client", hello.entity_type);
507
508 let session_id = format!(
510 "{}-{}",
511 hello.entity_type,
512 chrono::Utc::now().timestamp_millis()
513 );
514
515 let welcome_payload = WelcomePayload {
517 session_id,
518 capabilities: vec![], };
520
521 if send_protocol_message(&tx, "welcome", welcome_payload).is_ok() {
522 session_welcomed = true;
523 tracing::debug!("Sent welcome message");
524 } else {
525 tracing::error!("Failed to send welcome message");
526 }
527 },
528 "register" => {
529 if !session_welcomed {
531 tracing::warn!(
532 "MCP client registered without hello handshake (legacy client detected)"
533 );
534 }
535
536 let project: ProjectInfo =
538 match serde_json::from_value(parsed_msg.payload.clone()) {
539 Ok(p) => p,
540 Err(e) => {
541 tracing::warn!("Failed to parse register payload: {}", e);
542 continue;
543 },
544 };
545 tracing::info!("MCP registering project: {}", project.name);
546
547 let path = project.path.clone();
548 let project_path_buf = std::path::PathBuf::from(&path);
549
550 let normalized_path = project_path_buf
553 .canonicalize()
554 .unwrap_or_else(|_| project_path_buf.clone());
555
556 let temp_dir = std::env::temp_dir()
558 .canonicalize()
559 .unwrap_or_else(|_| std::env::temp_dir());
560 let is_temp_path = normalized_path.starts_with(&temp_dir);
561
562 if is_temp_path {
563 tracing::warn!(
564 "Rejecting MCP registration for temporary/invalid path: {}",
565 path
566 );
567
568 let error_payload = ErrorPayload {
570 code: error_codes::INVALID_PATH.to_string(),
571 message: "Path is in temporary directory".to_string(),
572 details: Some(serde_json::json!({"path": path})),
573 };
574 let _ = send_protocol_message(&tx, "error", error_payload);
575
576 let _ = send_protocol_message(
578 &tx,
579 "registered",
580 RegisteredPayload { success: false },
581 );
582 continue; }
584
585 let conn = McpConnection {
587 tx: tx.clone(),
588 project: project.clone(),
589 connected_at: chrono::Utc::now(),
590 };
591
592 state_for_recv
593 .mcp_connections
594 .write()
595 .await
596 .insert(path.clone(), conn);
597 project_path = Some(path.clone());
598
599 tracing::info!("✓ MCP connected: {} ({})", project.name, path);
600
601 let _ = send_protocol_message(
603 &tx,
604 "registered",
605 RegisteredPayload { success: true },
606 );
607
608 let mut project_info = project.clone();
610 project_info.mcp_connected = true;
611 let ui_msg = ProtocolMessage::new(
612 "project_online",
613 ProjectOnlinePayload {
614 project: project_info,
615 },
616 );
617 state_for_recv
618 .broadcast_to_ui(&ui_msg.to_json().unwrap())
619 .await;
620 },
621 "pong" => {
622 tracing::trace!("Received pong from MCP client - heartbeat confirmed");
624 },
625 "goodbye" => {
626 if let Ok(goodbye_payload) =
628 serde_json::from_value::<GoodbyePayload>(parsed_msg.payload)
629 {
630 if let Some(reason) = goodbye_payload.reason {
631 tracing::info!("MCP client closing connection: {}", reason);
632 } else {
633 tracing::info!("MCP client closing connection gracefully");
634 }
635 }
636 break;
638 },
639 "db_operation" => {
640 tracing::debug!(
643 "Received db_operation from MCP, forwarding to UI clients"
644 );
645 state_for_recv.broadcast_to_ui(&text).await;
646 },
647 _ => {
648 tracing::warn!("Unknown message type: {}", parsed_msg.message_type);
649 },
650 }
651 },
652 Message::Close(_) => {
653 tracing::info!("MCP client closed WebSocket");
654 break;
655 },
656 _ => {},
657 }
658 }
659
660 project_path
661 });
662
663 tokio::select! {
665 _ = (&mut send_task) => {
666 recv_task.abort();
667 heartbeat_task.abort();
668 }
669 project_path_result = (&mut recv_task) => {
670 send_task.abort();
671 heartbeat_task.abort();
672 if let Ok(Some(path)) = project_path_result {
673 state.mcp_connections.write().await.remove(&path);
675
676 tracing::info!("MCP disconnected: {}", path);
677
678 let ui_msg = ProtocolMessage::new(
680 "project_offline",
681 ProjectOfflinePayload { project_path: path.clone() },
682 );
683 state
684 .broadcast_to_ui(&ui_msg.to_json().unwrap())
685 .await;
686
687 tracing::info!("MCP disconnected: {}", path);
688 }
689 }
690 _ = (&mut heartbeat_task) => {
691 send_task.abort();
692 recv_task.abort();
693 }
694 }
695}
696
697pub async fn handle_ui_websocket(
699 ws: WebSocketUpgrade,
700 State(app_state): State<crate::dashboard::server::AppState>,
701) -> impl IntoResponse {
702 ws.on_upgrade(move |socket| handle_ui_socket(socket, app_state))
703}
704
705async fn handle_ui_socket(socket: WebSocket, app_state: crate::dashboard::server::AppState) {
706 let (mut sender, mut receiver) = socket.split();
707 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
708
709 let mut send_task = tokio::spawn(async move {
711 while let Some(msg) = rx.recv().await {
712 if sender.send(msg).await.is_err() {
713 break;
714 }
715 }
716 });
717
718 let conn = UiConnection {
724 tx: tx.clone(),
725 connected_at: chrono::Utc::now(),
726 };
727 let conn_index = {
728 let mut connections = app_state.ws_state.ui_connections.write().await;
729 connections.push(conn);
730 connections.len() - 1
731 };
732
733 tracing::info!("UI client connected");
734
735 let app_state_for_recv = app_state.clone();
737
738 let heartbeat_tx = tx.clone();
740
741 let mut heartbeat_task = tokio::spawn(async move {
743 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
744 interval.tick().await;
746
747 loop {
748 interval.tick().await;
749 if send_protocol_message(&heartbeat_tx, "ping", EmptyPayload {}).is_err() {
750 break;
752 }
753 tracing::trace!("Sent heartbeat ping to UI client");
754 }
755 });
756
757 let mut recv_task = tokio::spawn(async move {
759 while let Some(Ok(msg)) = receiver.next().await {
760 match msg {
761 Message::Text(text) => {
762 if let Ok(parsed_msg) =
764 serde_json::from_str::<ProtocolMessage<serde_json::Value>>(&text)
765 {
766 match parsed_msg.message_type.as_str() {
767 "hello" => {
768 if let Ok(hello) =
770 serde_json::from_value::<HelloPayload>(parsed_msg.payload)
771 {
772 tracing::info!(
773 "Received hello from {} client",
774 hello.entity_type
775 );
776
777 let session_id = format!(
779 "{}-{}",
780 hello.entity_type,
781 chrono::Utc::now().timestamp_millis()
782 );
783
784 let welcome_payload = WelcomePayload {
786 session_id,
787 capabilities: vec![],
788 };
789
790 let _ = send_protocol_message(&tx, "welcome", welcome_payload);
791 tracing::debug!("Sent welcome message to UI");
792
793 let current_projects = {
796 let current_project =
797 app_state_for_recv.current_project.read().await;
798 let port = app_state_for_recv.port;
799 app_state_for_recv
800 .ws_state
801 .get_online_projects_with_current(
802 ¤t_project.project_name,
803 ¤t_project.project_path,
804 ¤t_project.db_path,
805 port,
806 )
807 .await
808 };
809 let _ = send_protocol_message(
810 &tx,
811 "init",
812 InitPayload {
813 projects: current_projects,
814 },
815 );
816 }
817 },
818 "pong" => {
819 tracing::trace!("Received pong from UI");
820 },
821 "goodbye" => {
822 if let Ok(goodbye_payload) =
824 serde_json::from_value::<GoodbyePayload>(parsed_msg.payload)
825 {
826 if let Some(reason) = goodbye_payload.reason {
827 tracing::info!("UI client closing: {}", reason);
828 } else {
829 tracing::info!("UI client closing gracefully");
830 }
831 }
832 break;
833 },
834 _ => {
835 tracing::trace!(
836 "Received from UI: {} ({})",
837 parsed_msg.message_type,
838 text
839 );
840 },
841 }
842 } else {
843 tracing::trace!("Received non-protocol message from UI: {}", text);
844 }
845 },
846 Message::Pong(_) => {
847 tracing::trace!("Received WebSocket pong from UI");
848 },
849 Message::Close(_) => {
850 tracing::info!("UI client closed WebSocket");
851 break;
852 },
853 _ => {},
854 }
855 }
856 });
857
858 tokio::select! {
860 _ = (&mut send_task) => {
861 recv_task.abort();
862 heartbeat_task.abort();
863 }
864 _ = (&mut recv_task) => {
865 send_task.abort();
866 heartbeat_task.abort();
867 }
868 _ = (&mut heartbeat_task) => {
869 send_task.abort();
870 recv_task.abort();
871 }
872 }
873
874 app_state
876 .ws_state
877 .ui_connections
878 .write()
879 .await
880 .swap_remove(conn_index);
881 tracing::info!("UI client disconnected");
882}