1use axum::{
5 extract::{
6 ws::{Message, WebSocket},
7 State, WebSocketUpgrade,
8 },
9 response::IntoResponse,
10};
11use futures_util::{SinkExt, StreamExt};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::RwLock;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ProjectInfo {
20 pub path: String,
21 pub name: String,
22 pub db_path: String,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub agent: Option<String>,
25}
26
27#[derive(Debug)]
29pub struct McpConnection {
30 pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
31 pub project: ProjectInfo,
32 pub connected_at: chrono::DateTime<chrono::Utc>,
33}
34
35#[derive(Debug)]
37pub struct UiConnection {
38 pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
39 pub connected_at: chrono::DateTime<chrono::Utc>,
40}
41
42#[derive(Clone)]
44pub struct WebSocketState {
45 pub mcp_connections: Arc<RwLock<HashMap<String, McpConnection>>>,
47 pub ui_connections: Arc<RwLock<Vec<UiConnection>>>,
49}
50
51impl Default for WebSocketState {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl WebSocketState {
58 pub fn new() -> Self {
59 Self {
60 mcp_connections: Arc::new(RwLock::new(HashMap::new())),
61 ui_connections: Arc::new(RwLock::new(Vec::new())),
62 }
63 }
64
65 pub async fn broadcast_to_ui(&self, message: &str) {
67 let connections = self.ui_connections.read().await;
68 for conn in connections.iter() {
69 let _ = conn.tx.send(Message::Text(message.to_string()));
70 }
71 }
72
73 pub async fn get_online_projects(&self) -> Vec<ProjectInfo> {
75 match crate::dashboard::registry::ProjectRegistry::load() {
78 Ok(registry) => registry
79 .projects
80 .iter()
81 .filter(|p| p.mcp_connected)
82 .map(|p| ProjectInfo {
83 name: p.name.clone(),
84 path: p.path.display().to_string(),
85 db_path: p.db_path.display().to_string(),
86 agent: p.mcp_agent.clone(),
87 })
88 .collect(),
89 Err(e) => {
90 tracing::warn!("Failed to load registry for online projects: {}", e);
91 Vec::new()
92 },
93 }
94 }
95}
96
97#[derive(Debug, Deserialize)]
99#[serde(tag = "type")]
100enum McpMessage {
101 #[serde(rename = "register")]
102 Register { project: ProjectInfo },
103 #[serde(rename = "ping")]
104 Ping,
105}
106
107#[derive(Debug, Serialize)]
109#[serde(tag = "type")]
110enum McpResponse {
111 #[serde(rename = "registered")]
112 Registered { success: bool },
113 #[serde(rename = "pong")]
114 Pong,
115}
116
117#[derive(Debug, Serialize)]
119#[serde(tag = "type")]
120enum UiMessage {
121 #[serde(rename = "init")]
122 Init { projects: Vec<ProjectInfo> },
123 #[serde(rename = "project_online")]
124 ProjectOnline { project: ProjectInfo },
125 #[serde(rename = "project_offline")]
126 ProjectOffline { project_path: String },
127 #[serde(rename = "ping")]
128 Ping,
129}
130
131pub async fn handle_mcp_websocket(
133 ws: WebSocketUpgrade,
134 State(state): State<WebSocketState>,
135) -> impl IntoResponse {
136 ws.on_upgrade(move |socket| handle_mcp_socket(socket, state))
137}
138
139async fn handle_mcp_socket(socket: WebSocket, state: WebSocketState) {
140 let (mut sender, mut receiver) = socket.split();
141 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
142
143 let mut send_task = tokio::spawn(async move {
145 while let Some(msg) = rx.recv().await {
146 if sender.send(msg).await.is_err() {
147 break;
148 }
149 }
150 });
151
152 let mut project_path: Option<String> = None;
154
155 let state_for_recv = state.clone();
157
158 let heartbeat_tx = tx.clone();
160
161 let mut heartbeat_task = tokio::spawn(async move {
163 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
164 loop {
165 interval.tick().await;
166 let ping_msg = McpResponse::Pong; if heartbeat_tx
168 .send(Message::Text(serde_json::to_string(&ping_msg).unwrap()))
169 .is_err()
170 {
171 break;
173 }
174 tracing::trace!("Sent heartbeat to MCP client");
175 }
176 });
177
178 let mut recv_task = tokio::spawn(async move {
180 while let Some(Ok(msg)) = receiver.next().await {
181 match msg {
182 Message::Text(text) => {
183 match serde_json::from_str::<McpMessage>(&text) {
185 Ok(McpMessage::Register { project }) => {
186 tracing::info!("MCP registering project: {}", project.name);
187
188 let path = project.path.clone();
189 let project_path_buf = std::path::PathBuf::from(&path);
190
191 let normalized_path = project_path_buf
194 .canonicalize()
195 .unwrap_or_else(|_| project_path_buf.clone());
196
197 let temp_dir = std::env::temp_dir()
199 .canonicalize()
200 .unwrap_or_else(|_| std::env::temp_dir());
201 let is_temp_path = normalized_path.starts_with(&temp_dir);
202
203 if is_temp_path {
204 tracing::warn!(
205 "Rejecting MCP registration for temporary/invalid path: {}",
206 path
207 );
208
209 let response = McpResponse::Registered { success: false };
211 let _ = tx
212 .send(Message::Text(serde_json::to_string(&response).unwrap()));
213 continue; }
215
216 let conn = McpConnection {
218 tx: tx.clone(),
219 project: project.clone(),
220 connected_at: chrono::Utc::now(),
221 };
222
223 state_for_recv
224 .mcp_connections
225 .write()
226 .await
227 .insert(path.clone(), conn);
228 project_path = Some(path.clone());
229
230 match crate::dashboard::registry::ProjectRegistry::load() {
232 Ok(mut registry) => {
233 if let Err(e) = registry.register_mcp_connection(
234 &project_path_buf,
235 project.agent.clone(),
236 ) {
237 tracing::warn!(
238 "Failed to update Registry for MCP connection: {}",
239 e
240 );
241 } else {
242 tracing::info!(
243 "✓ Updated Registry: {} is now mcp_connected=true",
244 project.name
245 );
246 }
247 },
248 Err(e) => {
249 tracing::warn!("Failed to load Registry: {}", e);
250 },
251 }
252
253 let response = McpResponse::Registered { success: true };
255 let _ =
256 tx.send(Message::Text(serde_json::to_string(&response).unwrap()));
257
258 let ui_msg = UiMessage::ProjectOnline { project };
260 state_for_recv
261 .broadcast_to_ui(&serde_json::to_string(&ui_msg).unwrap())
262 .await;
263 },
264 Ok(McpMessage::Ping) => {
265 let response = McpResponse::Pong;
267 let _ =
268 tx.send(Message::Text(serde_json::to_string(&response).unwrap()));
269 },
270 Err(e) => {
271 tracing::warn!("Failed to parse MCP message: {}", e);
272 },
273 }
274 },
275 Message::Close(_) => {
276 break;
277 },
278 _ => {},
279 }
280 }
281
282 project_path
283 });
284
285 tokio::select! {
287 _ = (&mut send_task) => {
288 recv_task.abort();
289 heartbeat_task.abort();
290 }
291 project_path_result = (&mut recv_task) => {
292 send_task.abort();
293 heartbeat_task.abort();
294 if let Ok(Some(path)) = project_path_result {
295 state.mcp_connections.write().await.remove(&path);
297
298 let project_path_buf = std::path::PathBuf::from(&path);
300 match crate::dashboard::registry::ProjectRegistry::load() {
301 Ok(mut registry) => {
302 if let Err(e) = registry.unregister_mcp_connection(&project_path_buf) {
303 tracing::warn!("Failed to update Registry for MCP disconnection: {}", e);
304 } else {
305 tracing::info!("✓ Updated Registry: {} is now mcp_connected=false", path);
306 }
307 }
308 Err(e) => {
309 tracing::warn!("Failed to load Registry: {}", e);
310 }
311 }
312
313 let ui_msg = UiMessage::ProjectOffline { project_path: path.clone() };
315 state
316 .broadcast_to_ui(&serde_json::to_string(&ui_msg).unwrap())
317 .await;
318
319 tracing::info!("MCP disconnected: {}", path);
320 }
321 }
322 _ = (&mut heartbeat_task) => {
323 send_task.abort();
324 recv_task.abort();
325 }
326 }
327}
328
329pub async fn handle_ui_websocket(
331 ws: WebSocketUpgrade,
332 State(state): State<WebSocketState>,
333) -> impl IntoResponse {
334 ws.on_upgrade(move |socket| handle_ui_socket(socket, state))
335}
336
337async fn handle_ui_socket(socket: WebSocket, state: WebSocketState) {
338 let (mut sender, mut receiver) = socket.split();
339 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
340
341 let mut send_task = tokio::spawn(async move {
343 while let Some(msg) = rx.recv().await {
344 if sender.send(msg).await.is_err() {
345 break;
346 }
347 }
348 });
349
350 let projects = state.get_online_projects().await;
352 let init_msg = UiMessage::Init { projects };
353 let _ = tx.send(Message::Text(serde_json::to_string(&init_msg).unwrap()));
354
355 let conn = UiConnection {
357 tx: tx.clone(),
358 connected_at: chrono::Utc::now(),
359 };
360 let conn_index = {
361 let mut connections = state.ui_connections.write().await;
362 connections.push(conn);
363 connections.len() - 1
364 };
365
366 tracing::info!("UI client connected");
367
368 let heartbeat_tx = tx.clone();
370
371 let mut heartbeat_task = tokio::spawn(async move {
373 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
374 loop {
375 interval.tick().await;
376 let ping_msg = UiMessage::Ping;
377 if heartbeat_tx
378 .send(Message::Text(serde_json::to_string(&ping_msg).unwrap()))
379 .is_err()
380 {
381 break;
383 }
384 tracing::trace!("Sent heartbeat ping to UI client");
385 }
386 });
387
388 let mut recv_task = tokio::spawn(async move {
390 while let Some(Ok(msg)) = receiver.next().await {
391 match msg {
392 Message::Text(text) => {
393 tracing::trace!("Received from UI: {}", text);
395 },
396 Message::Pong(_) => {
397 tracing::trace!("Received pong from UI");
398 },
399 Message::Close(_) => {
400 break;
401 },
402 _ => {},
403 }
404 }
405 });
406
407 tokio::select! {
409 _ = (&mut send_task) => {
410 recv_task.abort();
411 heartbeat_task.abort();
412 }
413 _ = (&mut recv_task) => {
414 send_task.abort();
415 heartbeat_task.abort();
416 }
417 _ = (&mut heartbeat_task) => {
418 send_task.abort();
419 recv_task.abort();
420 }
421 }
422
423 state.ui_connections.write().await.swap_remove(conn_index);
425 tracing::info!("UI client disconnected");
426}