nexus_memory_web/
websocket.rs1use axum::{
4 extract::{State, WebSocketUpgrade},
5 http::{HeaderMap, StatusCode},
6 response::{IntoResponse, Response},
7};
8use futures::{sink::SinkExt, stream::StreamExt};
9use std::sync::Arc;
10use tokio::sync::{broadcast, mpsc, RwLock};
11use tracing::{error, info, warn};
12use url::Url;
13
14use crate::{models::WebSocketMessage, state::AppState};
15
16fn is_local_origin(headers: &HeaderMap) -> bool {
21 let origin_str = match headers.get("origin").and_then(|v| v.to_str().ok()) {
22 Some(s) => s,
23 None => return false, };
25 match Url::parse(origin_str) {
26 Ok(url) => {
27 let host = url.host_str().unwrap_or("");
28 let scheme = url.scheme();
29 (scheme == "http" || scheme == "https") && (host == "127.0.0.1" || host == "localhost")
30 }
31 Err(_) => false, }
33}
34
35pub async fn websocket_handler(
37 ws: WebSocketUpgrade,
38 headers: HeaderMap,
39 State(state): State<Arc<RwLock<AppState>>>,
40) -> Response {
41 if !is_local_origin(&headers) {
43 return (
44 StatusCode::FORBIDDEN,
45 "WebSocket connections are only allowed from local origins",
46 )
47 .into_response();
48 }
49
50 ws.on_upgrade(move |socket| handle_socket(socket, state))
51}
52
53async fn handle_socket(socket: axum::extract::ws::WebSocket, state: Arc<RwLock<AppState>>) {
55 let (mut sender, mut receiver) = socket.split();
56
57 let mut broadcast_rx = {
59 let state = state.read().await;
60 state.subscribe_ws()
61 };
62
63 let (direct_tx, mut direct_rx) = mpsc::channel::<WebSocketMessage>(16);
65
66 info!("WebSocket client connected");
67
68 let send_task = tokio::spawn(async move {
71 loop {
72 tokio::select! {
75 biased;
76 direct_msg = direct_rx.recv() => {
78 match direct_msg {
79 Some(msg) => {
80 if send_ws_message(&mut sender, &msg).await.is_err() {
81 break;
82 }
83 }
84 None => break, }
86 }
87 broadcast_result = broadcast_rx.recv() => {
89 match broadcast_result {
90 Ok(msg) => {
91 if send_ws_message(&mut sender, &msg).await.is_err() {
92 break;
93 }
94 }
95 Err(broadcast::error::RecvError::Lagged(n)) => {
96 warn!("WebSocket client lagged behind, dropped {} messages", n);
97 }
98 Err(broadcast::error::RecvError::Closed) => {
99 break;
100 }
101 }
102 }
103 }
104 }
105 });
106
107 while let Some(msg) = receiver.next().await {
109 match msg {
110 Ok(axum::extract::ws::Message::Text(text)) => {
111 match serde_json::from_str::<WebSocketMessage>(&text) {
113 Ok(ws_msg) => {
114 match ws_msg.message_type {
116 crate::models::WebSocketMessageType::Ping => {
117 let pong = WebSocketMessage::pong();
118 if direct_tx.send(pong).await.is_err() {
119 break;
120 }
121 }
122 _ => {
123 }
125 }
126 }
127 Err(e) => {
128 warn!("Invalid WebSocket message received: {}", e);
129 }
130 }
131 }
132 Ok(axum::extract::ws::Message::Close(_)) => {
133 info!("WebSocket client disconnected");
134 break;
135 }
136 Ok(_) => {
137 }
139 Err(e) => {
140 error!("WebSocket error: {}", e);
141 break;
142 }
143 }
144 }
145
146 send_task.abort();
148 info!("WebSocket connection closed");
149}
150
151async fn send_ws_message(
153 sender: &mut futures::stream::SplitSink<
154 axum::extract::ws::WebSocket,
155 axum::extract::ws::Message,
156 >,
157 msg: &WebSocketMessage,
158) -> Result<(), axum::Error> {
159 let json = match serde_json::to_string(msg) {
160 Ok(j) => j,
161 Err(e) => {
162 error!("Failed to serialize WebSocket message: {}", e);
163 return Ok(()); }
165 };
166
167 sender
168 .send(axum::extract::ws::Message::Text(json.into()))
169 .await
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::WebSocketMessageType;
    use crate::WebDashboard;
    use futures_util::StreamExt;
    use http::HeaderValue;
    use tokio::net::TcpListener;
    use tokio_tungstenite::tungstenite::client::IntoClientRequest;
    use tokio_tungstenite::tungstenite::protocol::Message as TungsteniteMessage;

    /// Builds a client handshake request for `url` that carries a loopback
    /// `Origin` header.
    ///
    /// `websocket_handler` answers `403` to upgrades without a local origin, and
    /// tungstenite does not add an `Origin` header by itself. Connecting with a
    /// bare URL would therefore always be rejected.
    fn local_ws_request(url: &str) -> tokio_tungstenite::tungstenite::handshake::client::Request {
        let mut request = url.into_client_request().expect("build client request");
        request.headers_mut().insert(
            "origin",
            tokio_tungstenite::tungstenite::http::HeaderValue::from_static("http://127.0.0.1"),
        );
        request
    }

    #[test]
    fn test_is_local_origin_accepts_https_localhost() {
        let mut headers = HeaderMap::new();
        headers.insert("origin", HeaderValue::from_static("https://localhost:8768"));

        assert!(is_local_origin(&headers));
    }

    #[test]
    fn test_is_local_origin_rejects_missing_and_remote_origins() {
        // No Origin header at all must be rejected.
        assert!(!is_local_origin(&HeaderMap::new()));

        let mut remote = HeaderMap::new();
        remote.insert("origin", HeaderValue::from_static("https://example.com"));
        assert!(!is_local_origin(&remote));

        // A loopback host with a non-web scheme is also rejected.
        let mut wrong_scheme = HeaderMap::new();
        wrong_scheme.insert("origin", HeaderValue::from_static("ftp://127.0.0.1"));
        assert!(!is_local_origin(&wrong_scheme));
    }

    #[tokio::test]
    #[ignore = "requires raw TCP bind on ephemeral port; flaky in restricted environments"]
    async fn test_ping_pong_isolation_direct_reply_only() {
        let pool = sqlx::SqlitePool::connect("sqlite::memory:")
            .await
            .expect("connect to in-memory db");
        nexus_storage::migrations::run_migrations(&pool)
            .await
            .expect("run migrations");

        let mut storage = nexus_storage::StorageManager::new(pool.clone());
        storage.initialize().await.expect("initialize storage");

        let dashboard = WebDashboard::new(storage, nexus_orchestrator::Orchestrator::default())
            .await
            .expect("create dashboard");

        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind to random port");
        let addr = listener.local_addr().expect("get local addr");

        let server_handle = tokio::spawn(async move {
            axum::serve(listener, dashboard.router).await.unwrap();
        });

        // Give the server a moment to start accepting connections.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        let url_a = format!("ws://127.0.0.1:{}/ws", addr.port());
        let url_b = format!("ws://127.0.0.1:{}/ws", addr.port());

        // Both clients must present a local Origin or the handler rejects the upgrade.
        let (mut ws_a, _) = tokio_tungstenite::connect_async(local_ws_request(&url_a))
            .await
            .expect("client A connect");
        let (mut ws_b, _) = tokio_tungstenite::connect_async(local_ws_request(&url_b))
            .await
            .expect("client B connect");

        // Discard any initial broadcast traffic so only the ping/pong exchange is observed.
        drain_messages(&mut ws_a, std::time::Duration::from_millis(200)).await;
        drain_messages(&mut ws_b, std::time::Duration::from_millis(200)).await;

        let ping_msg = WebSocketMessage::ping();
        let ping_json = serde_json::to_string(&ping_msg).expect("serialize ping");
        ws_a.send(TungsteniteMessage::Text(ping_json.into()))
            .await
            .expect("send ping from A");

        let reply_a = tokio::time::timeout(std::time::Duration::from_secs(2), ws_a.next())
            .await
            .expect("timeout waiting for pong on A")
            .expect("no message on A")
            .expect("error on A");

        let reply_text = match reply_a {
            TungsteniteMessage::Text(t) => t.to_string(),
            other => panic!("expected text message on A, got: {:?}", other),
        };

        let reply_msg: WebSocketMessage =
            serde_json::from_str(&reply_text).expect("parse pong on A");
        assert!(
            matches!(reply_msg.message_type, WebSocketMessageType::Pong),
            "expected Pong message type, got: {:?}",
            reply_msg.message_type
        );

        // The pong must be a direct reply; client B must see nothing.
        let b_reply =
            tokio::time::timeout(std::time::Duration::from_millis(500), ws_b.next()).await;

        assert!(
            b_reply.is_err(),
            "Client B received a message when it should not have \
             (ping from A must not be broadcast)"
        );

        server_handle.abort();
    }

    /// Reads and discards frames from `ws` until no frame arrives within
    /// `timeout`, or the stream ends or errors.
    async fn drain_messages(
        ws: &mut tokio_tungstenite::WebSocketStream<
            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
        >,
        timeout: std::time::Duration,
    ) {
        loop {
            match tokio::time::timeout(timeout, ws.next()).await {
                Ok(Some(Ok(_))) => continue,
                Ok(Some(Err(_))) => break,
                Ok(None) => break,
                Err(_) => break,
            }
        }
    }
}