Skip to main content

opendev_web/
websocket.rs

//! WebSocket handler for real-time communication.
2
3use axum::extract::ws::{Message, WebSocket};
4use axum::extract::{State, WebSocketUpgrade};
5use axum::response::IntoResponse;
6use futures::SinkExt;
7use futures::stream::StreamExt;
8use tracing::{debug, error, info, warn};
9
10use crate::protocol::WsMessageType;
11use crate::state::{AppState, WsBroadcast};
12
13/// WebSocket upgrade handler.
14pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> impl IntoResponse {
15    ws.on_upgrade(move |socket| handle_socket(socket, state))
16}
17
18/// Handle an individual WebSocket connection.
19async fn handle_socket(socket: WebSocket, state: AppState) {
20    let (mut sender, mut receiver) = socket.split();
21
22    // Subscribe to broadcast channel.
23    let mut broadcast_rx = state.ws_subscribe();
24
25    // Spawn task to forward broadcasts to this client.
26    let send_task = tokio::spawn(async move {
27        while let Ok(msg) = broadcast_rx.recv().await {
28            match serde_json::to_string(&msg) {
29                Ok(text) => {
30                    if sender.send(Message::Text(text.into())).await.is_err() {
31                        break;
32                    }
33                }
34                Err(e) => {
35                    error!("Failed to serialize broadcast message: {}", e);
36                }
37            }
38        }
39    });
40
41    // Receive messages from this client.
42    while let Some(Ok(msg)) = receiver.next().await {
43        match msg {
44            Message::Text(text) => {
45                handle_client_message(&state, &text).await;
46            }
47            Message::Close(_) => {
48                info!("WebSocket client disconnected");
49                break;
50            }
51            _ => {}
52        }
53    }
54
55    // Clean up the send task.
56    send_task.abort();
57    info!("WebSocket connection closed");
58}
59
60/// Handle a text message from a WebSocket client.
61async fn handle_client_message(state: &AppState, text: &str) {
62    let parsed: serde_json::Value = match serde_json::from_str(text) {
63        Ok(v) => v,
64        Err(e) => {
65            warn!("Invalid WebSocket message JSON: {}", e);
66            return;
67        }
68    };
69
70    let msg_type_str = parsed.get("type").and_then(|v| v.as_str()).unwrap_or("");
71
72    let msg_type = WsMessageType::from_str_opt(msg_type_str);
73
74    match msg_type {
75        Some(WsMessageType::Ping) => {
76            state.broadcast(WsBroadcast {
77                msg_type: WsMessageType::Pong.as_str().to_string(),
78                data: serde_json::Value::Null,
79            });
80        }
81        Some(WsMessageType::Query) => {
82            handle_query(state, &parsed).await;
83        }
84        Some(WsMessageType::Approve) => {
85            handle_approval(state, &parsed).await;
86        }
87        Some(WsMessageType::AskUserResponse) => {
88            handle_ask_user_response(state, &parsed).await;
89        }
90        Some(WsMessageType::PlanApprovalResponse) => {
91            handle_plan_approval_response(state, &parsed).await;
92        }
93        Some(WsMessageType::Interrupt) => {
94            handle_interrupt(state).await;
95        }
96        _ => {
97            if !msg_type_str.is_empty() {
98                warn!("Unknown WebSocket message type: {}", msg_type_str);
99            }
100            state.broadcast(WsBroadcast {
101                msg_type: WsMessageType::Error.as_str().to_string(),
102                data: serde_json::json!({
103                    "message": format!("Unknown message type: {}", msg_type_str),
104                }),
105            });
106        }
107    }
108}
109
110/// Handle a query message from a WebSocket client.
111async fn handle_query(state: &AppState, data: &serde_json::Value) {
112    let message = data
113        .get("data")
114        .and_then(|d| d.get("message"))
115        .and_then(|m| m.as_str());
116    let session_id = data
117        .get("data")
118        .and_then(|d| d.get("session_id"))
119        .and_then(|s| s.as_str());
120
121    let message = match message {
122        Some(m) if !m.trim().is_empty() => m.trim(),
123        _ => {
124            state.broadcast(WsBroadcast {
125                msg_type: WsMessageType::Error.as_str().to_string(),
126                data: serde_json::json!({"message": "Missing or empty message field"}),
127            });
128            return;
129        }
130    };
131
132    // Resolve session ID: use provided or fall back to current.
133    let session_id = match session_id {
134        Some(id) => id.to_string(),
135        None => match state.current_session_id().await {
136            Some(id) => id,
137            None => {
138                state.broadcast(WsBroadcast {
139                    msg_type: WsMessageType::Error.as_str().to_string(),
140                    data: serde_json::json!({"message": "No active session"}),
141                });
142                return;
143            }
144        },
145    };
146
147    // Bridge mode: route to TUI injector instead of agent executor.
148    if state.is_bridge_guarded(&session_id).await {
149        // In bridge mode, broadcast the user message then inject into the
150        // TUI's queue (same injection mechanism used for live messages).
151        state.broadcast(WsBroadcast {
152            msg_type: WsMessageType::UserMessage.as_str().to_string(),
153            data: serde_json::json!({
154                "role": "user",
155                "content": message,
156                "session_id": session_id,
157            }),
158        });
159
160        match state
161            .try_inject_message(&session_id, message.to_string())
162            .await
163        {
164            Ok(()) => {}
165            Err(e) => {
166                state.broadcast(WsBroadcast {
167                    msg_type: WsMessageType::Error.as_str().to_string(),
168                    data: serde_json::json!({
169                        "message": format!("Bridge mode injection failed: {}", e),
170                    }),
171                });
172            }
173        }
174        return;
175    }
176
177    // If session is already running, inject into live queue.
178    if state.is_session_running(&session_id).await {
179        match state
180            .try_inject_message(&session_id, message.to_string())
181            .await
182        {
183            Ok(()) => {
184                state.broadcast(WsBroadcast {
185                    msg_type: WsMessageType::UserMessage.as_str().to_string(),
186                    data: serde_json::json!({
187                        "role": "user",
188                        "content": message,
189                        "session_id": session_id,
190                        "injected": true,
191                    }),
192                });
193            }
194            Err(e) => {
195                state.broadcast(WsBroadcast {
196                    msg_type: WsMessageType::Error.as_str().to_string(),
197                    data: serde_json::json!({
198                        "message": e,
199                        "session_id": session_id,
200                    }),
201                });
202            }
203        }
204        return;
205    }
206
207    // Broadcast user message.
208    state.broadcast(WsBroadcast {
209        msg_type: WsMessageType::UserMessage.as_str().to_string(),
210        data: serde_json::json!({
211            "role": "user",
212            "content": message,
213            "session_id": session_id,
214        }),
215    });
216
217    // Fire the agent executor in the background (if set).
218    if let Some(executor) = state.agent_executor().await {
219        let state_clone = state.clone();
220        let message_owned = message.to_string();
221        let session_id_owned = session_id.clone();
222        tokio::spawn(async move {
223            if let Err(e) = executor
224                .execute_query(message_owned, session_id_owned, state_clone)
225                .await
226            {
227                error!("Agent executor error: {}", e);
228            }
229        });
230    } else {
231        debug!(
232            "Query received for session {} but no agent executor is set: {}",
233            session_id, message
234        );
235    }
236}
237
238/// Handle an approval response from a WebSocket client.
239async fn handle_approval(state: &AppState, data: &serde_json::Value) {
240    let approval_data = data.get("data").cloned().unwrap_or_default();
241    let approval_id = approval_data
242        .get("approvalId")
243        .and_then(|v| v.as_str())
244        .unwrap_or("");
245    let approved = approval_data
246        .get("approved")
247        .and_then(|v| v.as_bool())
248        .unwrap_or(false);
249    let auto_approve = approval_data
250        .get("autoApprove")
251        .and_then(|v| v.as_bool())
252        .unwrap_or(false);
253
254    if approval_id.is_empty() {
255        state.broadcast(WsBroadcast {
256            msg_type: WsMessageType::Error.as_str().to_string(),
257            data: serde_json::json!({"message": "Invalid approval data"}),
258        });
259        return;
260    }
261
262    let resolved = state
263        .resolve_approval(approval_id, approved, auto_approve)
264        .await;
265
266    if let Some(approval) = resolved {
267        info!("Approval {} resolved: approved={}", approval_id, approved);
268        state.broadcast(WsBroadcast {
269            msg_type: WsMessageType::ApprovalResolved.as_str().to_string(),
270            data: serde_json::json!({
271                "approvalId": approval_id,
272                "approved": approved,
273                "session_id": approval.session_id,
274            }),
275        });
276    } else {
277        warn!("Approval {} not found", approval_id);
278    }
279}
280
281/// Handle an ask-user response from a WebSocket client.
282async fn handle_ask_user_response(state: &AppState, data: &serde_json::Value) {
283    let response_data = data.get("data").cloned().unwrap_or_default();
284    let request_id = response_data
285        .get("requestId")
286        .and_then(|v| v.as_str())
287        .unwrap_or("");
288    let answers = response_data.get("answers").cloned();
289    let cancelled = response_data
290        .get("cancelled")
291        .and_then(|v| v.as_bool())
292        .unwrap_or(false);
293
294    if request_id.is_empty() {
295        state.broadcast(WsBroadcast {
296            msg_type: WsMessageType::Error.as_str().to_string(),
297            data: serde_json::json!({"message": "Invalid ask-user response data"}),
298        });
299        return;
300    }
301
302    let resolved = state.resolve_ask_user(request_id, answers, cancelled).await;
303
304    if let Some(ask_user) = resolved {
305        info!("Ask-user {} resolved", request_id);
306        state.broadcast(WsBroadcast {
307            msg_type: WsMessageType::AskUserResolved.as_str().to_string(),
308            data: serde_json::json!({
309                "requestId": request_id,
310                "session_id": ask_user.session_id,
311            }),
312        });
313    } else {
314        warn!("Ask-user request {} not found", request_id);
315    }
316}
317
318/// Handle a plan approval response from a WebSocket client.
319async fn handle_plan_approval_response(state: &AppState, data: &serde_json::Value) {
320    let response_data = data.get("data").cloned().unwrap_or_default();
321    let request_id = response_data
322        .get("requestId")
323        .and_then(|v| v.as_str())
324        .unwrap_or("");
325    let action = response_data
326        .get("action")
327        .and_then(|v| v.as_str())
328        .unwrap_or("reject")
329        .to_string();
330    let feedback = response_data
331        .get("feedback")
332        .and_then(|v| v.as_str())
333        .unwrap_or("")
334        .to_string();
335
336    if request_id.is_empty() {
337        state.broadcast(WsBroadcast {
338            msg_type: WsMessageType::Error.as_str().to_string(),
339            data: serde_json::json!({"message": "Invalid plan approval response data"}),
340        });
341        return;
342    }
343
344    let resolved = state
345        .resolve_plan_approval(request_id, action.clone(), feedback)
346        .await;
347
348    if let Some(plan_approval) = resolved {
349        info!("Plan approval {} resolved: action={}", request_id, action);
350        state.broadcast(WsBroadcast {
351            msg_type: WsMessageType::PlanApprovalResolved.as_str().to_string(),
352            data: serde_json::json!({
353                "requestId": request_id,
354                "action": action,
355                "session_id": plan_approval.session_id,
356            }),
357        });
358    } else {
359        warn!("Plan approval request {} not found", request_id);
360    }
361}
362
363/// Handle an interrupt request from a WebSocket client.
364async fn handle_interrupt(state: &AppState) {
365    info!("Interrupt requested via WebSocket");
366    state.request_interrupt().await;
367
368    state.broadcast(WsBroadcast {
369        msg_type: WsMessageType::StatusUpdate.as_str().to_string(),
370        data: serde_json::json!({
371            "interrupted": true,
372        }),
373    });
374}