1use axum::{
5 extract::{
6 ws::{Message, WebSocket},
7 State, WebSocketUpgrade,
8 },
9 response::IntoResponse,
10};
11use futures_util::{SinkExt, StreamExt};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::RwLock;
16
/// Description of a project as exchanged over the MCP and UI WebSocket channels.
///
/// MCP clients send it inside a `register` message; UI clients receive it in
/// `init` and `project_online` messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectInfo {
    /// Project path as reported by the client; used as the key in
    /// `WebSocketState::mcp_connections`.
    pub path: String,
    /// Human-readable project name.
    pub name: String,
    /// Database path for the project, in display (string) form.
    pub db_path: String,
    /// Agent name reported for the MCP connection, if any.
    /// Omitted from serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent: Option<String>,
}
26
/// A registered MCP client connection, stored in
/// `WebSocketState::mcp_connections` under the project's path.
#[derive(Debug)]
pub struct McpConnection {
    /// Outbound message queue for this client; drained into the socket by
    /// the connection's writer task.
    pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
    /// The project this client registered.
    pub project: ProjectInfo,
    /// Time (UTC) at which the registration was accepted.
    pub connected_at: chrono::DateTime<chrono::Utc>,
}
34
/// A connected dashboard UI client, stored in `WebSocketState::ui_connections`
/// so it receives `broadcast_to_ui` messages.
#[derive(Debug)]
pub struct UiConnection {
    /// Outbound message queue for this client; drained into the socket by
    /// the connection's writer task.
    pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
    /// Time (UTC) at which the client connected.
    pub connected_at: chrono::DateTime<chrono::Utc>,
}
41
/// Shared state for the MCP and UI WebSocket endpoints.
///
/// Cloning is cheap: both tables live behind `Arc`, so every clone observes
/// the same connections.
#[derive(Clone)]
pub struct WebSocketState {
    /// Registered MCP connections, keyed by the project path they registered.
    pub mcp_connections: Arc<RwLock<HashMap<String, McpConnection>>>,
    /// Currently connected UI clients (broadcast targets).
    pub ui_connections: Arc<RwLock<Vec<UiConnection>>>,
}
50
51impl Default for WebSocketState {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl WebSocketState {
58 pub fn new() -> Self {
59 Self {
60 mcp_connections: Arc::new(RwLock::new(HashMap::new())),
61 ui_connections: Arc::new(RwLock::new(Vec::new())),
62 }
63 }
64
65 pub async fn broadcast_to_ui(&self, message: &str) {
67 let connections = self.ui_connections.read().await;
68 for conn in connections.iter() {
69 let _ = conn.tx.send(Message::Text(message.to_string()));
70 }
71 }
72
73 pub async fn get_online_projects(&self) -> Vec<ProjectInfo> {
75 match crate::dashboard::registry::ProjectRegistry::load() {
78 Ok(registry) => registry
79 .projects
80 .iter()
81 .filter(|p| p.mcp_connected)
82 .map(|p| ProjectInfo {
83 name: p.name.clone(),
84 path: p.path.display().to_string(),
85 db_path: p.db_path.display().to_string(),
86 agent: p.mcp_agent.clone(),
87 })
88 .collect(),
89 Err(e) => {
90 tracing::warn!("Failed to load registry for online projects: {}", e);
91 Vec::new()
92 },
93 }
94 }
95}
96
97#[derive(Debug, Deserialize)]
99#[serde(tag = "type")]
100enum McpMessage {
101 #[serde(rename = "register")]
102 Register { project: ProjectInfo },
103 #[serde(rename = "ping")]
104 Ping,
105}
106
107#[derive(Debug, Serialize)]
109#[serde(tag = "type")]
110enum McpResponse {
111 #[serde(rename = "registered")]
112 Registered { success: bool },
113 #[serde(rename = "pong")]
114 Pong,
115}
116
117#[derive(Debug, Serialize)]
119#[serde(tag = "type")]
120enum UiMessage {
121 #[serde(rename = "init")]
122 Init { projects: Vec<ProjectInfo> },
123 #[serde(rename = "project_online")]
124 ProjectOnline { project: ProjectInfo },
125 #[serde(rename = "project_offline")]
126 ProjectOffline { project_path: String },
127 #[serde(rename = "ping")]
128 Ping,
129}
130
131pub async fn handle_mcp_websocket(
133 ws: WebSocketUpgrade,
134 State(state): State<WebSocketState>,
135) -> impl IntoResponse {
136 ws.on_upgrade(move |socket| handle_mcp_socket(socket, state))
137}
138
139async fn handle_mcp_socket(socket: WebSocket, state: WebSocketState) {
140 let (mut sender, mut receiver) = socket.split();
141 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
142
143 let mut send_task = tokio::spawn(async move {
145 while let Some(msg) = rx.recv().await {
146 if sender.send(msg).await.is_err() {
147 break;
148 }
149 }
150 });
151
152 let mut project_path: Option<String> = None;
154
155 let state_for_recv = state.clone();
157
158 let heartbeat_tx = tx.clone();
160
161 let mut heartbeat_task = tokio::spawn(async move {
163 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
164 loop {
165 interval.tick().await;
166 let ping_msg = McpResponse::Pong; if heartbeat_tx
168 .send(Message::Text(serde_json::to_string(&ping_msg).unwrap()))
169 .is_err()
170 {
171 break;
173 }
174 tracing::trace!("Sent heartbeat to MCP client");
175 }
176 });
177
178 let mut recv_task = tokio::spawn(async move {
180 while let Some(Ok(msg)) = receiver.next().await {
181 match msg {
182 Message::Text(text) => {
183 match serde_json::from_str::<McpMessage>(&text) {
185 Ok(McpMessage::Register { project }) => {
186 tracing::info!("MCP registering project: {}", project.name);
187
188 let path = project.path.clone();
189 let project_path_buf = std::path::PathBuf::from(&path);
190
191 let normalized_path = project_path_buf
194 .canonicalize()
195 .unwrap_or_else(|_| project_path_buf.clone());
196
197 let temp_dir = std::env::temp_dir();
198 let is_temp_path = normalized_path.starts_with(&temp_dir);
199
200 if is_temp_path {
201 tracing::warn!(
202 "Rejecting MCP registration for temporary/invalid path: {}",
203 path
204 );
205
206 let response = McpResponse::Registered { success: false };
208 let _ = tx
209 .send(Message::Text(serde_json::to_string(&response).unwrap()));
210 continue; }
212
213 let conn = McpConnection {
215 tx: tx.clone(),
216 project: project.clone(),
217 connected_at: chrono::Utc::now(),
218 };
219
220 state_for_recv
221 .mcp_connections
222 .write()
223 .await
224 .insert(path.clone(), conn);
225 project_path = Some(path.clone());
226
227 match crate::dashboard::registry::ProjectRegistry::load() {
229 Ok(mut registry) => {
230 if let Err(e) = registry.register_mcp_connection(
231 &project_path_buf,
232 project.agent.clone(),
233 ) {
234 tracing::warn!(
235 "Failed to update Registry for MCP connection: {}",
236 e
237 );
238 } else {
239 tracing::info!(
240 "✓ Updated Registry: {} is now mcp_connected=true",
241 project.name
242 );
243 }
244 },
245 Err(e) => {
246 tracing::warn!("Failed to load Registry: {}", e);
247 },
248 }
249
250 let response = McpResponse::Registered { success: true };
252 let _ =
253 tx.send(Message::Text(serde_json::to_string(&response).unwrap()));
254
255 let ui_msg = UiMessage::ProjectOnline { project };
257 state_for_recv
258 .broadcast_to_ui(&serde_json::to_string(&ui_msg).unwrap())
259 .await;
260 },
261 Ok(McpMessage::Ping) => {
262 let response = McpResponse::Pong;
264 let _ =
265 tx.send(Message::Text(serde_json::to_string(&response).unwrap()));
266 },
267 Err(e) => {
268 tracing::warn!("Failed to parse MCP message: {}", e);
269 },
270 }
271 },
272 Message::Close(_) => {
273 break;
274 },
275 _ => {},
276 }
277 }
278
279 project_path
280 });
281
282 tokio::select! {
284 _ = (&mut send_task) => {
285 recv_task.abort();
286 heartbeat_task.abort();
287 }
288 project_path_result = (&mut recv_task) => {
289 send_task.abort();
290 heartbeat_task.abort();
291 if let Ok(Some(path)) = project_path_result {
292 state.mcp_connections.write().await.remove(&path);
294
295 let project_path_buf = std::path::PathBuf::from(&path);
297 match crate::dashboard::registry::ProjectRegistry::load() {
298 Ok(mut registry) => {
299 if let Err(e) = registry.unregister_mcp_connection(&project_path_buf) {
300 tracing::warn!("Failed to update Registry for MCP disconnection: {}", e);
301 } else {
302 tracing::info!("✓ Updated Registry: {} is now mcp_connected=false", path);
303 }
304 }
305 Err(e) => {
306 tracing::warn!("Failed to load Registry: {}", e);
307 }
308 }
309
310 let ui_msg = UiMessage::ProjectOffline { project_path: path.clone() };
312 state
313 .broadcast_to_ui(&serde_json::to_string(&ui_msg).unwrap())
314 .await;
315
316 tracing::info!("MCP disconnected: {}", path);
317 }
318 }
319 _ = (&mut heartbeat_task) => {
320 send_task.abort();
321 recv_task.abort();
322 }
323 }
324}
325
326pub async fn handle_ui_websocket(
328 ws: WebSocketUpgrade,
329 State(state): State<WebSocketState>,
330) -> impl IntoResponse {
331 ws.on_upgrade(move |socket| handle_ui_socket(socket, state))
332}
333
334async fn handle_ui_socket(socket: WebSocket, state: WebSocketState) {
335 let (mut sender, mut receiver) = socket.split();
336 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
337
338 let mut send_task = tokio::spawn(async move {
340 while let Some(msg) = rx.recv().await {
341 if sender.send(msg).await.is_err() {
342 break;
343 }
344 }
345 });
346
347 let projects = state.get_online_projects().await;
349 let init_msg = UiMessage::Init { projects };
350 let _ = tx.send(Message::Text(serde_json::to_string(&init_msg).unwrap()));
351
352 let conn = UiConnection {
354 tx: tx.clone(),
355 connected_at: chrono::Utc::now(),
356 };
357 let conn_index = {
358 let mut connections = state.ui_connections.write().await;
359 connections.push(conn);
360 connections.len() - 1
361 };
362
363 tracing::info!("UI client connected");
364
365 let heartbeat_tx = tx.clone();
367
368 let mut heartbeat_task = tokio::spawn(async move {
370 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
371 loop {
372 interval.tick().await;
373 let ping_msg = UiMessage::Ping;
374 if heartbeat_tx
375 .send(Message::Text(serde_json::to_string(&ping_msg).unwrap()))
376 .is_err()
377 {
378 break;
380 }
381 tracing::trace!("Sent heartbeat ping to UI client");
382 }
383 });
384
385 let mut recv_task = tokio::spawn(async move {
387 while let Some(Ok(msg)) = receiver.next().await {
388 match msg {
389 Message::Text(text) => {
390 tracing::trace!("Received from UI: {}", text);
392 },
393 Message::Pong(_) => {
394 tracing::trace!("Received pong from UI");
395 },
396 Message::Close(_) => {
397 break;
398 },
399 _ => {},
400 }
401 }
402 });
403
404 tokio::select! {
406 _ = (&mut send_task) => {
407 recv_task.abort();
408 heartbeat_task.abort();
409 }
410 _ = (&mut recv_task) => {
411 send_task.abort();
412 heartbeat_task.abort();
413 }
414 _ = (&mut heartbeat_task) => {
415 send_task.abort();
416 recv_task.abort();
417 }
418 }
419
420 state.ui_connections.write().await.swap_remove(conn_index);
422 tracing::info!("UI client disconnected");
423}