nexus_memory_web/
websocket.rs1use axum::{
4 extract::{State, WebSocketUpgrade},
5 response::Response,
6};
7use futures::{sink::SinkExt, stream::StreamExt};
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use tracing::{error, info, warn};
11
12use crate::{models::WebSocketMessage, state::AppState};
13
14pub async fn websocket_handler(
16 ws: WebSocketUpgrade,
17 State(state): State<Arc<RwLock<AppState>>>,
18) -> Response {
19 ws.on_upgrade(move |socket| handle_socket(socket, state))
20}
21
22async fn handle_socket(socket: axum::extract::ws::WebSocket, state: Arc<RwLock<AppState>>) {
24 let (mut sender, mut receiver) = socket.split();
25
26 let mut broadcast_rx = {
28 let state = state.read().await;
29 state.subscribe_ws()
30 };
31
32 info!("WebSocket client connected");
33
34 let send_task = tokio::spawn(async move {
36 loop {
37 match broadcast_rx.recv().await {
38 Ok(msg) => {
39 let json = match serde_json::to_string(&msg) {
40 Ok(j) => j,
41 Err(e) => {
42 error!("Failed to serialize WebSocket message: {}", e);
43 continue;
44 }
45 };
46
47 if sender
48 .send(axum::extract::ws::Message::Text(json.into()))
49 .await
50 .is_err()
51 {
52 break;
53 }
54 }
55 Err(e) => {
56 error!("Broadcast receive error: {}", e);
57 break;
58 }
59 }
60 }
61 });
62
63 while let Some(msg) = receiver.next().await {
65 match msg {
66 Ok(axum::extract::ws::Message::Text(text)) => {
67 match serde_json::from_str::<WebSocketMessage>(&text) {
69 Ok(ws_msg) => {
70 match ws_msg.message_type {
72 crate::models::WebSocketMessageType::Ping => {
73 let pong = WebSocketMessage::pong();
74 let app_state = state.read().await;
75 if let Err(e) = app_state.broadcast_ws(pong) {
76 warn!("Failed to broadcast WebSocket pong: {}", e);
77 }
78 }
79 _ => {
80 }
82 }
83 }
84 Err(e) => {
85 warn!("Invalid WebSocket message received: {}", e);
86 }
87 }
88 }
89 Ok(axum::extract::ws::Message::Close(_)) => {
90 info!("WebSocket client disconnected");
91 break;
92 }
93 Ok(_) => {
94 }
96 Err(e) => {
97 error!("WebSocket error: {}", e);
98 break;
99 }
100 }
101 }
102
103 send_task.abort();
105 info!("WebSocket connection closed");
106}