1use axum::extract::ws::{Message, WebSocket};
2use futures_util::StreamExt;
3use serde::{Deserialize, Serialize};
4
5use crate::bridge::BridgeClient;
6
7#[derive(Deserialize)]
8struct IncomingMessage {
9 message: String,
10}
11
12#[derive(Serialize)]
13#[serde(untagged)]
14enum OutgoingMessage {
15 Status { status: String },
16 Response { response: String },
17 Error { error: String },
18}
19
20pub async fn handle_socket(mut socket: WebSocket, bridge: BridgeClient) {
22 while let Some(Ok(msg)) = socket.next().await {
23 let text = match msg {
24 Message::Text(t) => t,
25 Message::Close(_) => break,
26 _ => continue,
27 };
28
29 let incoming: IncomingMessage = match serde_json::from_str(&text) {
30 Ok(m) => m,
31 Err(_) => {
32 let _ = send_json(
33 &mut socket,
34 &OutgoingMessage::Error {
35 error: "invalid message format".into(),
36 },
37 )
38 .await;
39 continue;
40 }
41 };
42
43 if incoming.message.trim().is_empty() {
44 continue;
45 }
46
47 let _ = send_json(
49 &mut socket,
50 &OutgoingMessage::Status {
51 status: "thinking".into(),
52 },
53 )
54 .await;
55
56 match bridge.send(&incoming.message).await {
58 Ok(response) => {
59 let _ = send_json(&mut socket, &OutgoingMessage::Response { response }).await;
60 }
61 Err(e) => {
62 tracing::error!(error = %e, "bridge-echo request failed");
63 let _ = send_json(&mut socket, &OutgoingMessage::Error { error: e }).await;
64 }
65 }
66 }
67
68 tracing::info!("WebSocket connection closed");
69}
70
71async fn send_json(socket: &mut WebSocket, msg: &OutgoingMessage) -> Result<(), axum::Error> {
72 let text = serde_json::to_string(msg).unwrap_or_default();
73 socket.send(Message::text(text)).await
74}