Skip to main content

nexus_memory_web/
websocket.rs

1//! WebSocket handler for real-time updates
2
3use axum::{
4    extract::{State, WebSocketUpgrade},
5    response::Response,
6};
7use futures::{sink::SinkExt, stream::StreamExt};
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use tracing::{error, info, warn};
11
12use crate::{models::WebSocketMessage, state::AppState};
13
14/// WebSocket connection handler
15pub async fn websocket_handler(
16    ws: WebSocketUpgrade,
17    State(state): State<Arc<RwLock<AppState>>>,
18) -> Response {
19    ws.on_upgrade(move |socket| handle_socket(socket, state))
20}
21
22/// Handle a WebSocket connection
23async fn handle_socket(socket: axum::extract::ws::WebSocket, state: Arc<RwLock<AppState>>) {
24    let (mut sender, mut receiver) = socket.split();
25
26    // Subscribe to broadcast channel
27    let mut broadcast_rx = {
28        let state = state.read().await;
29        state.subscribe_ws()
30    };
31
32    info!("WebSocket client connected");
33
34    // Spawn task to forward broadcast messages to this client
35    let send_task = tokio::spawn(async move {
36        loop {
37            match broadcast_rx.recv().await {
38                Ok(msg) => {
39                    let json = match serde_json::to_string(&msg) {
40                        Ok(j) => j,
41                        Err(e) => {
42                            error!("Failed to serialize WebSocket message: {}", e);
43                            continue;
44                        }
45                    };
46
47                    if sender
48                        .send(axum::extract::ws::Message::Text(json.into()))
49                        .await
50                        .is_err()
51                    {
52                        break;
53                    }
54                }
55                Err(e) => {
56                    error!("Broadcast receive error: {}", e);
57                    break;
58                }
59            }
60        }
61    });
62
63    // Handle incoming messages from client
64    while let Some(msg) = receiver.next().await {
65        match msg {
66            Ok(axum::extract::ws::Message::Text(text)) => {
67                // Parse the message
68                match serde_json::from_str::<WebSocketMessage>(&text) {
69                    Ok(ws_msg) => {
70                        // Handle ping/pong
71                        match ws_msg.message_type {
72                            crate::models::WebSocketMessageType::Ping => {
73                                let pong = WebSocketMessage::pong();
74                                let app_state = state.read().await;
75                                if let Err(e) = app_state.broadcast_ws(pong) {
76                                    warn!("Failed to broadcast WebSocket pong: {}", e);
77                                }
78                            }
79                            _ => {
80                                // Handle other message types if needed
81                            }
82                        }
83                    }
84                    Err(e) => {
85                        warn!("Invalid WebSocket message received: {}", e);
86                    }
87                }
88            }
89            Ok(axum::extract::ws::Message::Close(_)) => {
90                info!("WebSocket client disconnected");
91                break;
92            }
93            Ok(_) => {
94                // Ignore other message types
95            }
96            Err(e) => {
97                error!("WebSocket error: {}", e);
98                break;
99            }
100        }
101    }
102
103    // Abort the send task when client disconnects
104    send_task.abort();
105    info!("WebSocket connection closed");
106}