Skip to main content

nexus_memory_web/
websocket.rs

1//! WebSocket handler for real-time updates
2
3use axum::{
4    extract::{State, WebSocketUpgrade},
5    http::{HeaderMap, StatusCode},
6    response::{IntoResponse, Response},
7};
8use futures::{sink::SinkExt, stream::StreamExt};
9use std::sync::Arc;
10use tokio::sync::{broadcast, mpsc, RwLock};
11use tracing::{error, info, warn};
12use url::Url;
13
14use crate::{models::WebSocketMessage, state::AppState};
15
16/// Validate that the request Origin header matches an exact local origin.
17/// Parses the Origin as a URL and compares scheme + host exactly to prevent
18/// prefix-spoofing attacks (e.g. http://localhost.evil.com).
19/// Missing Origin headers are rejected to enforce the local-only trust model.
20fn is_local_origin(headers: &HeaderMap) -> bool {
21    let origin_str = match headers.get("origin").and_then(|v| v.to_str().ok()) {
22        Some(s) => s,
23        None => return false, // Reject missing Origin — non-browser clients must send it
24    };
25    match Url::parse(origin_str) {
26        Ok(url) => {
27            let host = url.host_str().unwrap_or("");
28            let scheme = url.scheme();
29            (scheme == "http" || scheme == "https") && (host == "127.0.0.1" || host == "localhost")
30        }
31        Err(_) => false, // Malformed origins are rejected
32    }
33}
34
35/// WebSocket connection handler
36pub async fn websocket_handler(
37    ws: WebSocketUpgrade,
38    headers: HeaderMap,
39    State(state): State<Arc<RwLock<AppState>>>,
40) -> Response {
41    // Reject cross-origin WebSocket upgrades
42    if !is_local_origin(&headers) {
43        return (
44            StatusCode::FORBIDDEN,
45            "WebSocket connections are only allowed from local origins",
46        )
47            .into_response();
48    }
49
50    ws.on_upgrade(move |socket| handle_socket(socket, state))
51}
52
53/// Handle a WebSocket connection
54async fn handle_socket(socket: axum::extract::ws::WebSocket, state: Arc<RwLock<AppState>>) {
55    let (mut sender, mut receiver) = socket.split();
56
57    // Subscribe to broadcast channel
58    let mut broadcast_rx = {
59        let state = state.read().await;
60        state.subscribe_ws()
61    };
62
63    // Channel for direct replies (pong, etc.) from the message handler to the send task
64    let (direct_tx, mut direct_rx) = mpsc::channel::<WebSocketMessage>(16);
65
66    info!("WebSocket client connected");
67
68    // Spawn task to forward messages to this client.
69    // Handles both broadcast events and direct replies.
70    let send_task = tokio::spawn(async move {
71        loop {
72            // `biased` ensures direct replies (e.g. pong) always preempt
73            // broadcast events when both channels are ready simultaneously.
74            tokio::select! {
75                biased;
76                // Priority: direct replies first (e.g. pong responses)
77                direct_msg = direct_rx.recv() => {
78                    match direct_msg {
79                        Some(msg) => {
80                            if send_ws_message(&mut sender, &msg).await.is_err() {
81                                break;
82                            }
83                        }
84                        None => break, // channel closed
85                    }
86                }
87                // Broadcast events
88                broadcast_result = broadcast_rx.recv() => {
89                    match broadcast_result {
90                        Ok(msg) => {
91                            if send_ws_message(&mut sender, &msg).await.is_err() {
92                                break;
93                            }
94                        }
95                        Err(broadcast::error::RecvError::Lagged(n)) => {
96                            warn!("WebSocket client lagged behind, dropped {} messages", n);
97                        }
98                        Err(broadcast::error::RecvError::Closed) => {
99                            break;
100                        }
101                    }
102                }
103            }
104        }
105    });
106
107    // Handle incoming messages from client
108    while let Some(msg) = receiver.next().await {
109        match msg {
110            Ok(axum::extract::ws::Message::Text(text)) => {
111                // Parse the message
112                match serde_json::from_str::<WebSocketMessage>(&text) {
113                    Ok(ws_msg) => {
114                        // Handle ping/pong
115                        match ws_msg.message_type {
116                            crate::models::WebSocketMessageType::Ping => {
117                                let pong = WebSocketMessage::pong();
118                                if direct_tx.send(pong).await.is_err() {
119                                    break;
120                                }
121                            }
122                            _ => {
123                                // Handle other message types if needed
124                            }
125                        }
126                    }
127                    Err(e) => {
128                        warn!("Invalid WebSocket message received: {}", e);
129                    }
130                }
131            }
132            Ok(axum::extract::ws::Message::Close(_)) => {
133                info!("WebSocket client disconnected");
134                break;
135            }
136            Ok(_) => {
137                // Ignore other message types
138            }
139            Err(e) => {
140                error!("WebSocket error: {}", e);
141                break;
142            }
143        }
144    }
145
146    // Abort the send task when client disconnects
147    send_task.abort();
148    info!("WebSocket connection closed");
149}
150
151/// Serialize and send a single WebSocketMessage to the client.
152async fn send_ws_message(
153    sender: &mut futures::stream::SplitSink<
154        axum::extract::ws::WebSocket,
155        axum::extract::ws::Message,
156    >,
157    msg: &WebSocketMessage,
158) -> Result<(), axum::Error> {
159    let json = match serde_json::to_string(msg) {
160        Ok(j) => j,
161        Err(e) => {
162            error!("Failed to serialize WebSocket message: {}", e);
163            return Ok(()); // skip bad message, keep connection alive
164        }
165    };
166
167    sender
168        .send(axum::extract::ws::Message::Text(json.into()))
169        .await
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::WebSocketMessageType;
    use crate::WebDashboard;
    use futures_util::StreamExt;
    use http::HeaderValue;
    use tokio::net::TcpListener;
    use tokio_tungstenite::tungstenite::client::IntoClientRequest;
    use tokio_tungstenite::tungstenite::protocol::Message as TungsteniteMessage;

    /// Client-side stream type produced by `tokio_tungstenite::connect_async`.
    type ClientSocket = tokio_tungstenite::WebSocketStream<
        tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
    >;

    /// Build a header map containing only the given `Origin` value.
    fn headers_with_origin(origin: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("origin", HeaderValue::from_static(origin));
        headers
    }

    #[test]
    fn test_is_local_origin_accepts_https_localhost() {
        assert!(is_local_origin(&headers_with_origin(
            "https://localhost:8768"
        )));
    }

    #[test]
    fn test_is_local_origin_accepts_http_loopback_ip() {
        assert!(is_local_origin(&headers_with_origin(
            "http://127.0.0.1:8768"
        )));
    }

    #[test]
    fn test_is_local_origin_rejects_missing_origin() {
        assert!(!is_local_origin(&HeaderMap::new()));
    }

    #[test]
    fn test_is_local_origin_rejects_prefix_spoofed_host() {
        assert!(!is_local_origin(&headers_with_origin(
            "http://localhost.evil.com"
        )));
        assert!(!is_local_origin(&headers_with_origin(
            "http://localhost@evil.com"
        )));
    }

    #[test]
    fn test_is_local_origin_rejects_non_http_scheme_and_malformed_origin() {
        assert!(!is_local_origin(&headers_with_origin("ftp://localhost")));
        // "null" is not an absolute URL, so parsing fails and the origin is rejected.
        assert!(!is_local_origin(&headers_with_origin("null")));
    }

    /// Build a WebSocket handshake request for the local test server that
    /// carries a local `Origin` header.
    ///
    /// `websocket_handler` rejects upgrades without an accepted `Origin`, and a
    /// request built from a bare URL does not include one, so the header has to
    /// be added explicitly.
    fn local_origin_request(
        port: u16,
    ) -> tokio_tungstenite::tungstenite::handshake::client::Request {
        let mut request = format!("ws://127.0.0.1:{}/ws", port)
            .into_client_request()
            .expect("build WebSocket client request");
        request.headers_mut().insert(
            "origin",
            format!("http://127.0.0.1:{}", port)
                .parse()
                .expect("origin is a valid header value"),
        );
        request
    }

    /// Verifies that a WebSocket `ping` from one client receives a direct `pong`
    /// reply to that client only, and is NOT broadcast to other connected clients.
    ///
    /// Marked `#[ignore]` because it requires raw TCP socket binding which can
    /// fail in restricted CI environments (PermissionDenied on ephemeral ports).
    /// Can be run locally with `--include-ignored` when socket access is available.
    #[tokio::test]
    #[ignore = "requires raw TCP bind on ephemeral port; flaky in restricted environments"]
    async fn test_ping_pong_isolation_direct_reply_only() {
        let pool = sqlx::SqlitePool::connect("sqlite::memory:")
            .await
            .expect("connect to in-memory db");
        nexus_storage::migrations::run_migrations(&pool)
            .await
            .expect("run migrations");

        let mut storage = nexus_storage::StorageManager::new(pool.clone());
        storage.initialize().await.expect("initialize storage");

        let dashboard = WebDashboard::new(storage, nexus_orchestrator::Orchestrator::default())
            .await
            .expect("create dashboard");

        // Bind to port 0 to get a random available port
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind to random port");
        let addr = listener.local_addr().expect("get local addr");

        // Spawn the server. The listener is already bound, so client
        // connections are queued until the server starts accepting them and no
        // start-up sleep is needed.
        let server_handle = tokio::spawn(async move {
            axum::serve(listener, dashboard.router).await.unwrap();
        });

        // Connect two WebSocket clients, each presenting a local Origin
        let (mut ws_a, _) = tokio_tungstenite::connect_async(local_origin_request(addr.port()))
            .await
            .expect("client A connect");
        let (mut ws_b, _) = tokio_tungstenite::connect_async(local_origin_request(addr.port()))
            .await
            .expect("client B connect");

        // Drain any initial messages from both clients (subscription setup noise)
        drain_messages(&mut ws_a, std::time::Duration::from_millis(200)).await;
        drain_messages(&mut ws_b, std::time::Duration::from_millis(200)).await;

        // Client A sends a ping
        let ping_msg = WebSocketMessage::ping();
        let ping_json = serde_json::to_string(&ping_msg).expect("serialize ping");
        ws_a.send(TungsteniteMessage::Text(ping_json.into()))
            .await
            .expect("send ping from A");

        // Client A should receive the pong directly
        let reply_a = tokio::time::timeout(std::time::Duration::from_secs(2), ws_a.next())
            .await
            .expect("timeout waiting for pong on A")
            .expect("no message on A")
            .expect("error on A");

        let reply_text = match reply_a {
            TungsteniteMessage::Text(t) => t.to_string(),
            other => panic!("expected text message on A, got: {:?}", other),
        };

        let reply_msg: WebSocketMessage =
            serde_json::from_str(&reply_text).expect("parse pong on A");
        assert!(
            matches!(reply_msg.message_type, WebSocketMessageType::Pong),
            "expected Pong message type, got: {:?}",
            reply_msg.message_type
        );

        // Client B should NOT receive the pong (it was a ping from A).
        // Wait a short period and verify nothing arrives on B.
        let b_reply =
            tokio::time::timeout(std::time::Duration::from_millis(500), ws_b.next()).await;

        assert!(
            b_reply.is_err(),
            "Client B received a message when it should not have \
             (ping from A must not be broadcast)"
        );

        // Clean up
        server_handle.abort();
    }

    /// Drain any pending messages from a WebSocket connection.
    ///
    /// Keeps reading while messages arrive successfully; stops on the first
    /// read that errors, ends the stream, or exceeds `timeout` (taken to mean
    /// nothing further is pending).
    async fn drain_messages(ws: &mut ClientSocket, timeout: std::time::Duration) {
        while let Ok(Some(Ok(_))) = tokio::time::timeout(timeout, ws.next()).await {}
    }
}