1use axum::extract::ws::{Message, WebSocket};
4use axum::extract::{State, WebSocketUpgrade};
5use axum::response::IntoResponse;
6use futures::SinkExt;
7use futures::stream::StreamExt;
8use tracing::{debug, error, info, warn};
9
10use crate::protocol::WsMessageType;
11use crate::state::{AppState, WsBroadcast};
12
13pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> impl IntoResponse {
15 ws.on_upgrade(move |socket| handle_socket(socket, state))
16}
17
18async fn handle_socket(socket: WebSocket, state: AppState) {
20 let (mut sender, mut receiver) = socket.split();
21
22 let mut broadcast_rx = state.ws_subscribe();
24
25 let send_task = tokio::spawn(async move {
27 while let Ok(msg) = broadcast_rx.recv().await {
28 match serde_json::to_string(&msg) {
29 Ok(text) => {
30 if sender.send(Message::Text(text.into())).await.is_err() {
31 break;
32 }
33 }
34 Err(e) => {
35 error!("Failed to serialize broadcast message: {}", e);
36 }
37 }
38 }
39 });
40
41 while let Some(Ok(msg)) = receiver.next().await {
43 match msg {
44 Message::Text(text) => {
45 handle_client_message(&state, &text).await;
46 }
47 Message::Close(_) => {
48 info!("WebSocket client disconnected");
49 break;
50 }
51 _ => {}
52 }
53 }
54
55 send_task.abort();
57 info!("WebSocket connection closed");
58}
59
60async fn handle_client_message(state: &AppState, text: &str) {
62 let parsed: serde_json::Value = match serde_json::from_str(text) {
63 Ok(v) => v,
64 Err(e) => {
65 warn!("Invalid WebSocket message JSON: {}", e);
66 return;
67 }
68 };
69
70 let msg_type_str = parsed.get("type").and_then(|v| v.as_str()).unwrap_or("");
71
72 let msg_type = WsMessageType::from_str_opt(msg_type_str);
73
74 match msg_type {
75 Some(WsMessageType::Ping) => {
76 state.broadcast(WsBroadcast {
77 msg_type: WsMessageType::Pong.as_str().to_string(),
78 data: serde_json::Value::Null,
79 });
80 }
81 Some(WsMessageType::Query) => {
82 handle_query(state, &parsed).await;
83 }
84 Some(WsMessageType::Approve) => {
85 handle_approval(state, &parsed).await;
86 }
87 Some(WsMessageType::AskUserResponse) => {
88 handle_ask_user_response(state, &parsed).await;
89 }
90 Some(WsMessageType::PlanApprovalResponse) => {
91 handle_plan_approval_response(state, &parsed).await;
92 }
93 Some(WsMessageType::Interrupt) => {
94 handle_interrupt(state).await;
95 }
96 _ => {
97 if !msg_type_str.is_empty() {
98 warn!("Unknown WebSocket message type: {}", msg_type_str);
99 }
100 state.broadcast(WsBroadcast {
101 msg_type: WsMessageType::Error.as_str().to_string(),
102 data: serde_json::json!({
103 "message": format!("Unknown message type: {}", msg_type_str),
104 }),
105 });
106 }
107 }
108}
109
110async fn handle_query(state: &AppState, data: &serde_json::Value) {
112 let message = data
113 .get("data")
114 .and_then(|d| d.get("message"))
115 .and_then(|m| m.as_str());
116 let session_id = data
117 .get("data")
118 .and_then(|d| d.get("session_id"))
119 .and_then(|s| s.as_str());
120
121 let message = match message {
122 Some(m) if !m.trim().is_empty() => m.trim(),
123 _ => {
124 state.broadcast(WsBroadcast {
125 msg_type: WsMessageType::Error.as_str().to_string(),
126 data: serde_json::json!({"message": "Missing or empty message field"}),
127 });
128 return;
129 }
130 };
131
132 let session_id = match session_id {
134 Some(id) => id.to_string(),
135 None => match state.current_session_id().await {
136 Some(id) => id,
137 None => {
138 state.broadcast(WsBroadcast {
139 msg_type: WsMessageType::Error.as_str().to_string(),
140 data: serde_json::json!({"message": "No active session"}),
141 });
142 return;
143 }
144 },
145 };
146
147 if state.is_bridge_guarded(&session_id).await {
149 state.broadcast(WsBroadcast {
152 msg_type: WsMessageType::UserMessage.as_str().to_string(),
153 data: serde_json::json!({
154 "role": "user",
155 "content": message,
156 "session_id": session_id,
157 }),
158 });
159
160 match state
161 .try_inject_message(&session_id, message.to_string())
162 .await
163 {
164 Ok(()) => {}
165 Err(e) => {
166 state.broadcast(WsBroadcast {
167 msg_type: WsMessageType::Error.as_str().to_string(),
168 data: serde_json::json!({
169 "message": format!("Bridge mode injection failed: {}", e),
170 }),
171 });
172 }
173 }
174 return;
175 }
176
177 if state.is_session_running(&session_id).await {
179 match state
180 .try_inject_message(&session_id, message.to_string())
181 .await
182 {
183 Ok(()) => {
184 state.broadcast(WsBroadcast {
185 msg_type: WsMessageType::UserMessage.as_str().to_string(),
186 data: serde_json::json!({
187 "role": "user",
188 "content": message,
189 "session_id": session_id,
190 "injected": true,
191 }),
192 });
193 }
194 Err(e) => {
195 state.broadcast(WsBroadcast {
196 msg_type: WsMessageType::Error.as_str().to_string(),
197 data: serde_json::json!({
198 "message": e,
199 "session_id": session_id,
200 }),
201 });
202 }
203 }
204 return;
205 }
206
207 state.broadcast(WsBroadcast {
209 msg_type: WsMessageType::UserMessage.as_str().to_string(),
210 data: serde_json::json!({
211 "role": "user",
212 "content": message,
213 "session_id": session_id,
214 }),
215 });
216
217 if let Some(executor) = state.agent_executor().await {
219 let state_clone = state.clone();
220 let message_owned = message.to_string();
221 let session_id_owned = session_id.clone();
222 tokio::spawn(async move {
223 if let Err(e) = executor
224 .execute_query(message_owned, session_id_owned, state_clone)
225 .await
226 {
227 error!("Agent executor error: {}", e);
228 }
229 });
230 } else {
231 debug!(
232 "Query received for session {} but no agent executor is set: {}",
233 session_id, message
234 );
235 }
236}
237
238async fn handle_approval(state: &AppState, data: &serde_json::Value) {
240 let approval_data = data.get("data").cloned().unwrap_or_default();
241 let approval_id = approval_data
242 .get("approvalId")
243 .and_then(|v| v.as_str())
244 .unwrap_or("");
245 let approved = approval_data
246 .get("approved")
247 .and_then(|v| v.as_bool())
248 .unwrap_or(false);
249 let auto_approve = approval_data
250 .get("autoApprove")
251 .and_then(|v| v.as_bool())
252 .unwrap_or(false);
253
254 if approval_id.is_empty() {
255 state.broadcast(WsBroadcast {
256 msg_type: WsMessageType::Error.as_str().to_string(),
257 data: serde_json::json!({"message": "Invalid approval data"}),
258 });
259 return;
260 }
261
262 let resolved = state
263 .resolve_approval(approval_id, approved, auto_approve)
264 .await;
265
266 if let Some(approval) = resolved {
267 info!("Approval {} resolved: approved={}", approval_id, approved);
268 state.broadcast(WsBroadcast {
269 msg_type: WsMessageType::ApprovalResolved.as_str().to_string(),
270 data: serde_json::json!({
271 "approvalId": approval_id,
272 "approved": approved,
273 "session_id": approval.session_id,
274 }),
275 });
276 } else {
277 warn!("Approval {} not found", approval_id);
278 }
279}
280
281async fn handle_ask_user_response(state: &AppState, data: &serde_json::Value) {
283 let response_data = data.get("data").cloned().unwrap_or_default();
284 let request_id = response_data
285 .get("requestId")
286 .and_then(|v| v.as_str())
287 .unwrap_or("");
288 let answers = response_data.get("answers").cloned();
289 let cancelled = response_data
290 .get("cancelled")
291 .and_then(|v| v.as_bool())
292 .unwrap_or(false);
293
294 if request_id.is_empty() {
295 state.broadcast(WsBroadcast {
296 msg_type: WsMessageType::Error.as_str().to_string(),
297 data: serde_json::json!({"message": "Invalid ask-user response data"}),
298 });
299 return;
300 }
301
302 let resolved = state.resolve_ask_user(request_id, answers, cancelled).await;
303
304 if let Some(ask_user) = resolved {
305 info!("Ask-user {} resolved", request_id);
306 state.broadcast(WsBroadcast {
307 msg_type: WsMessageType::AskUserResolved.as_str().to_string(),
308 data: serde_json::json!({
309 "requestId": request_id,
310 "session_id": ask_user.session_id,
311 }),
312 });
313 } else {
314 warn!("Ask-user request {} not found", request_id);
315 }
316}
317
318async fn handle_plan_approval_response(state: &AppState, data: &serde_json::Value) {
320 let response_data = data.get("data").cloned().unwrap_or_default();
321 let request_id = response_data
322 .get("requestId")
323 .and_then(|v| v.as_str())
324 .unwrap_or("");
325 let action = response_data
326 .get("action")
327 .and_then(|v| v.as_str())
328 .unwrap_or("reject")
329 .to_string();
330 let feedback = response_data
331 .get("feedback")
332 .and_then(|v| v.as_str())
333 .unwrap_or("")
334 .to_string();
335
336 if request_id.is_empty() {
337 state.broadcast(WsBroadcast {
338 msg_type: WsMessageType::Error.as_str().to_string(),
339 data: serde_json::json!({"message": "Invalid plan approval response data"}),
340 });
341 return;
342 }
343
344 let resolved = state
345 .resolve_plan_approval(request_id, action.clone(), feedback)
346 .await;
347
348 if let Some(plan_approval) = resolved {
349 info!("Plan approval {} resolved: action={}", request_id, action);
350 state.broadcast(WsBroadcast {
351 msg_type: WsMessageType::PlanApprovalResolved.as_str().to_string(),
352 data: serde_json::json!({
353 "requestId": request_id,
354 "action": action,
355 "session_id": plan_approval.session_id,
356 }),
357 });
358 } else {
359 warn!("Plan approval request {} not found", request_id);
360 }
361}
362
363async fn handle_interrupt(state: &AppState) {
365 info!("Interrupt requested via WebSocket");
366 state.request_interrupt().await;
367
368 state.broadcast(WsBroadcast {
369 msg_type: WsMessageType::StatusUpdate.as_str().to_string(),
370 data: serde_json::json!({
371 "interrupted": true,
372 }),
373 });
374}