1use axum::{
5 extract::{
6 ws::{Message, WebSocket},
7 State, WebSocketUpgrade,
8 },
9 response::IntoResponse,
10};
11use futures_util::{SinkExt, StreamExt};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::RwLock;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ProjectInfo {
20 pub path: String,
21 pub name: String,
22 pub db_path: String,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub agent: Option<String>,
25}
26
27#[derive(Debug)]
29pub struct McpConnection {
30 pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
31 pub project: ProjectInfo,
32 pub connected_at: chrono::DateTime<chrono::Utc>,
33}
34
35#[derive(Debug)]
37pub struct UiConnection {
38 pub tx: tokio::sync::mpsc::UnboundedSender<Message>,
39 pub connected_at: chrono::DateTime<chrono::Utc>,
40}
41
42#[derive(Clone)]
44pub struct WebSocketState {
45 pub mcp_connections: Arc<RwLock<HashMap<String, McpConnection>>>,
47 pub ui_connections: Arc<RwLock<Vec<UiConnection>>>,
49}
50
51impl Default for WebSocketState {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl WebSocketState {
58 pub fn new() -> Self {
59 Self {
60 mcp_connections: Arc::new(RwLock::new(HashMap::new())),
61 ui_connections: Arc::new(RwLock::new(Vec::new())),
62 }
63 }
64
65 pub async fn broadcast_to_ui(&self, message: &str) {
67 let connections = self.ui_connections.read().await;
68 for conn in connections.iter() {
69 let _ = conn.tx.send(Message::Text(message.to_string()));
70 }
71 }
72
73 pub async fn get_online_projects(&self) -> Vec<ProjectInfo> {
75 match crate::dashboard::registry::ProjectRegistry::load() {
78 Ok(registry) => registry
79 .projects
80 .iter()
81 .filter(|p| p.mcp_connected)
82 .map(|p| ProjectInfo {
83 name: p.name.clone(),
84 path: p.path.display().to_string(),
85 db_path: p.db_path.display().to_string(),
86 agent: p.mcp_agent.clone(),
87 })
88 .collect(),
89 Err(e) => {
90 tracing::warn!("Failed to load registry for online projects: {}", e);
91 Vec::new()
92 },
93 }
94 }
95}
96
97#[derive(Debug, Deserialize)]
99#[serde(tag = "type")]
100enum McpMessage {
101 #[serde(rename = "register")]
102 Register { project: ProjectInfo },
103 #[serde(rename = "ping")]
104 Ping,
105}
106
107#[derive(Debug, Serialize)]
109#[serde(tag = "type")]
110enum McpResponse {
111 #[serde(rename = "registered")]
112 Registered { success: bool },
113 #[serde(rename = "pong")]
114 Pong,
115}
116
117#[derive(Debug, Serialize)]
119#[serde(tag = "type")]
120enum UiMessage {
121 #[serde(rename = "init")]
122 Init { projects: Vec<ProjectInfo> },
123 #[serde(rename = "project_online")]
124 ProjectOnline { project: ProjectInfo },
125 #[serde(rename = "project_offline")]
126 ProjectOffline { project_path: String },
127 #[serde(rename = "ping")]
128 Ping,
129}
130
131pub async fn handle_mcp_websocket(
133 ws: WebSocketUpgrade,
134 State(state): State<WebSocketState>,
135) -> impl IntoResponse {
136 ws.on_upgrade(move |socket| handle_mcp_socket(socket, state))
137}
138
139async fn handle_mcp_socket(socket: WebSocket, state: WebSocketState) {
140 let (mut sender, mut receiver) = socket.split();
141 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
142
143 let mut send_task = tokio::spawn(async move {
145 while let Some(msg) = rx.recv().await {
146 if sender.send(msg).await.is_err() {
147 break;
148 }
149 }
150 });
151
152 let mut project_path: Option<String> = None;
154
155 let state_for_recv = state.clone();
157
158 let heartbeat_tx = tx.clone();
160
161 let mut heartbeat_task = tokio::spawn(async move {
163 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
164 loop {
165 interval.tick().await;
166 let ping_msg = McpResponse::Pong; if heartbeat_tx
168 .send(Message::Text(serde_json::to_string(&ping_msg).unwrap()))
169 .is_err()
170 {
171 break;
173 }
174 tracing::trace!("Sent heartbeat to MCP client");
175 }
176 });
177
178 let mut recv_task = tokio::spawn(async move {
180 while let Some(Ok(msg)) = receiver.next().await {
181 match msg {
182 Message::Text(text) => {
183 match serde_json::from_str::<McpMessage>(&text) {
185 Ok(McpMessage::Register { project }) => {
186 tracing::info!("MCP registering project: {}", project.name);
187
188 let path = project.path.clone();
190 let conn = McpConnection {
191 tx: tx.clone(),
192 project: project.clone(),
193 connected_at: chrono::Utc::now(),
194 };
195
196 state_for_recv
197 .mcp_connections
198 .write()
199 .await
200 .insert(path.clone(), conn);
201 project_path = Some(path.clone());
202
203 let project_path_buf = std::path::PathBuf::from(&path);
205 match crate::dashboard::registry::ProjectRegistry::load() {
206 Ok(mut registry) => {
207 if let Err(e) = registry.register_mcp_connection(
208 &project_path_buf,
209 Some("mcp-client".to_string()),
210 ) {
211 tracing::warn!(
212 "Failed to update Registry for MCP connection: {}",
213 e
214 );
215 } else {
216 tracing::info!(
217 "✓ Updated Registry: {} is now mcp_connected=true",
218 project.name
219 );
220 }
221 },
222 Err(e) => {
223 tracing::warn!("Failed to load Registry: {}", e);
224 },
225 }
226
227 let response = McpResponse::Registered { success: true };
229 let _ =
230 tx.send(Message::Text(serde_json::to_string(&response).unwrap()));
231
232 let ui_msg = UiMessage::ProjectOnline { project };
234 state_for_recv
235 .broadcast_to_ui(&serde_json::to_string(&ui_msg).unwrap())
236 .await;
237 },
238 Ok(McpMessage::Ping) => {
239 let response = McpResponse::Pong;
241 let _ =
242 tx.send(Message::Text(serde_json::to_string(&response).unwrap()));
243 },
244 Err(e) => {
245 tracing::warn!("Failed to parse MCP message: {}", e);
246 },
247 }
248 },
249 Message::Close(_) => {
250 break;
251 },
252 _ => {},
253 }
254 }
255
256 project_path
257 });
258
259 tokio::select! {
261 _ = (&mut send_task) => {
262 recv_task.abort();
263 heartbeat_task.abort();
264 }
265 project_path_result = (&mut recv_task) => {
266 send_task.abort();
267 heartbeat_task.abort();
268 if let Ok(Some(path)) = project_path_result {
269 state.mcp_connections.write().await.remove(&path);
271
272 let project_path_buf = std::path::PathBuf::from(&path);
274 match crate::dashboard::registry::ProjectRegistry::load() {
275 Ok(mut registry) => {
276 if let Err(e) = registry.unregister_mcp_connection(&project_path_buf) {
277 tracing::warn!("Failed to update Registry for MCP disconnection: {}", e);
278 } else {
279 tracing::info!("✓ Updated Registry: {} is now mcp_connected=false", path);
280 }
281 }
282 Err(e) => {
283 tracing::warn!("Failed to load Registry: {}", e);
284 }
285 }
286
287 let ui_msg = UiMessage::ProjectOffline { project_path: path.clone() };
289 state
290 .broadcast_to_ui(&serde_json::to_string(&ui_msg).unwrap())
291 .await;
292
293 tracing::info!("MCP disconnected: {}", path);
294 }
295 }
296 _ = (&mut heartbeat_task) => {
297 send_task.abort();
298 recv_task.abort();
299 }
300 }
301}
302
303pub async fn handle_ui_websocket(
305 ws: WebSocketUpgrade,
306 State(state): State<WebSocketState>,
307) -> impl IntoResponse {
308 ws.on_upgrade(move |socket| handle_ui_socket(socket, state))
309}
310
311async fn handle_ui_socket(socket: WebSocket, state: WebSocketState) {
312 let (mut sender, mut receiver) = socket.split();
313 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
314
315 let mut send_task = tokio::spawn(async move {
317 while let Some(msg) = rx.recv().await {
318 if sender.send(msg).await.is_err() {
319 break;
320 }
321 }
322 });
323
324 let projects = state.get_online_projects().await;
326 let init_msg = UiMessage::Init { projects };
327 let _ = tx.send(Message::Text(serde_json::to_string(&init_msg).unwrap()));
328
329 let conn = UiConnection {
331 tx: tx.clone(),
332 connected_at: chrono::Utc::now(),
333 };
334 let conn_index = {
335 let mut connections = state.ui_connections.write().await;
336 connections.push(conn);
337 connections.len() - 1
338 };
339
340 tracing::info!("UI client connected");
341
342 let heartbeat_tx = tx.clone();
344
345 let mut heartbeat_task = tokio::spawn(async move {
347 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
348 loop {
349 interval.tick().await;
350 let ping_msg = UiMessage::Ping;
351 if heartbeat_tx
352 .send(Message::Text(serde_json::to_string(&ping_msg).unwrap()))
353 .is_err()
354 {
355 break;
357 }
358 tracing::trace!("Sent heartbeat ping to UI client");
359 }
360 });
361
362 let mut recv_task = tokio::spawn(async move {
364 while let Some(Ok(msg)) = receiver.next().await {
365 match msg {
366 Message::Text(text) => {
367 tracing::trace!("Received from UI: {}", text);
369 },
370 Message::Pong(_) => {
371 tracing::trace!("Received pong from UI");
372 },
373 Message::Close(_) => {
374 break;
375 },
376 _ => {},
377 }
378 }
379 });
380
381 tokio::select! {
383 _ = (&mut send_task) => {
384 recv_task.abort();
385 heartbeat_task.abort();
386 }
387 _ = (&mut recv_task) => {
388 send_task.abort();
389 heartbeat_task.abort();
390 }
391 _ = (&mut heartbeat_task) => {
392 send_task.abort();
393 recv_task.abort();
394 }
395 }
396
397 state.ui_connections.write().await.swap_remove(conn_index);
399 tracing::info!("UI client disconnected");
400}