1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
6use axum::http::StatusCode;
7use axum::response::IntoResponse;
8use serde::Deserialize;
9use subtle::ConstantTimeEq;
10use tokio::sync::broadcast;
11use tokio::time::{Instant, interval};
12
13use crate::ws_ticket::TicketStore;
14
15pub use roboticus_pipeline::event_bus::EventBus;
16
17#[derive(Deserialize)]
18struct WsQuery {
19 ticket: Option<String>,
20}
21
22pub fn ws_route(
30 bus: EventBus,
31 tickets: TicketStore,
32 api_key: Option<String>,
33) -> axum::routing::MethodRouter {
34 let api_key: Option<Arc<str>> = api_key.map(|k| Arc::from(k.as_str()));
35
36 let handler =
37 move |ws: WebSocketUpgrade,
38 headers: axum::http::HeaderMap,
39 axum::extract::ConnectInfo(peer_addr): axum::extract::ConnectInfo<SocketAddr>,
40 axum::extract::Query(query): axum::extract::Query<WsQuery>| {
41 let bus = bus.clone();
42 let tickets = tickets.clone();
43 let api_key = api_key.clone();
44 async move {
45 if !ws_authenticate(
46 &headers,
47 &query,
48 &tickets,
49 api_key.as_deref(),
50 Some(peer_addr),
51 ) {
52 return (StatusCode::UNAUTHORIZED, "Valid API key or ticket required")
53 .into_response();
54 }
55 ws.on_upgrade(move |socket| handle_socket(socket, bus))
56 .into_response()
57 }
58 };
59 axum::routing::get(handler)
60}
61
62fn ws_authenticate(
64 headers: &axum::http::HeaderMap,
65 query: &WsQuery,
66 tickets: &TicketStore,
67 api_key: Option<&str>,
68 peer_addr: Option<SocketAddr>,
69) -> bool {
70 let Some(expected) = api_key else {
73 return peer_addr.is_some_and(|addr| addr.ip().is_loopback());
74 };
75
76 if let Some(val) = headers.get("x-api-key")
78 && let Ok(provided) = val.to_str()
79 && bool::from(provided.as_bytes().ct_eq(expected.as_bytes()))
80 {
81 return true;
82 }
83
84 if let Some(val) = headers.get("authorization")
86 && let Ok(s) = val.to_str()
87 && let Some(token) = s.strip_prefix("Bearer ")
88 && bool::from(token.as_bytes().ct_eq(expected.as_bytes()))
89 {
90 return true;
91 }
92
93 if let Some(ref ticket) = query.ticket
95 && tickets.redeem(ticket)
96 {
97 return true;
98 }
99
100 false
101}
102
/// How often the server sends a WebSocket ping to each connected client.
const PING_INTERVAL: Duration = Duration::from_secs(30);
/// How long a connection may go without recorded activity before the server
/// closes it: three ping intervals (90 seconds).
const IDLE_TIMEOUT: Duration = Duration::from_secs(PING_INTERVAL.as_secs() * 3);
105
106async fn handle_socket(mut socket: WebSocket, bus: EventBus) {
107 let mut rx = bus.subscribe();
108
109 let welcome = serde_json::json!({
111 "type": "connected",
112 "version": env!("CARGO_PKG_VERSION"),
113 "timestamp": chrono::Utc::now().to_rfc3339(),
114 });
115 if let Err(e) = socket.send(Message::Text(welcome.to_string().into())).await {
116 tracing::debug!(error = %e, "WebSocket welcome send failed");
117 return;
118 }
119
120 let mut ping_timer = interval(PING_INTERVAL);
121 ping_timer.tick().await; let mut last_activity = Instant::now();
123
124 loop {
126 tokio::select! {
127 msg = rx.recv() => {
128 match msg {
129 Ok(event) => {
130 if socket.send(Message::Text(event.into())).await.is_err() {
131 break; }
133 last_activity = Instant::now();
134 }
135 Err(broadcast::error::RecvError::Lagged(n)) => {
136 tracing::warn!(skipped = n, "WebSocket subscriber lagged, skipping lost events");
137 continue;
138 }
139 Err(broadcast::error::RecvError::Closed) => break,
140 }
141 }
142 msg = socket.recv() => {
143 match msg {
144 Some(Ok(Message::Text(text))) => {
145 last_activity = Instant::now();
146 if text.len() > 4096 {
148 tracing::warn!(len = text.len(), "WebSocket message exceeds 4KiB limit, closing");
149 break;
150 }
151 let resp = serde_json::json!({ "type": "ack" });
152 if let Err(e) = socket.send(Message::Text(resp.to_string().into())).await {
153 tracing::debug!(error = %e, "WebSocket ack send failed");
154 break;
155 }
156 }
157 Some(Ok(Message::Ping(data))) => {
158 last_activity = Instant::now();
159 if let Err(e) = socket.send(Message::Pong(data)).await {
160 tracing::debug!(error = %e, "WebSocket pong send failed");
161 break;
162 }
163 }
164 Some(Ok(Message::Pong(_))) => {
165 last_activity = Instant::now();
166 }
167 Some(Ok(Message::Close(_))) | None => break,
168 _ => {}
169 }
170 }
171 _ = ping_timer.tick() => {
172 if last_activity.elapsed() > IDLE_TIMEOUT {
173 tracing::info!("WebSocket idle timeout, closing connection");
174 let _ = socket.send(Message::Close(None)).await;
175 break;
176 }
177 if let Err(e) = socket.send(Message::Ping(vec![].into())).await {
178 tracing::debug!(error = %e, "WebSocket ping send failed");
179 break;
180 }
181 }
182 }
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// API key configured for the `ws_auth_*` tests.
    const TEST_KEY: &str = "test-key";

    /// Runs [`ws_authenticate`] with [`TEST_KEY`] configured and no peer address.
    fn auth_with_test_key(
        headers: &axum::http::HeaderMap,
        query: &WsQuery,
        tickets: &TicketStore,
    ) -> bool {
        ws_authenticate(headers, query, tickets, Some(TEST_KEY), None)
    }

    /// Builds a header map containing a single header.
    fn headers_with(name: &'static str, value: &str) -> axum::http::HeaderMap {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(name, value.parse().unwrap());
        headers
    }

    // ---- EventBus -------------------------------------------------------------

    #[tokio::test]
    async fn publish_and_receive() {
        let bus = EventBus::new(16);
        let mut rx = bus.subscribe();

        bus.publish("hello".to_string());
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg, "hello");
    }

    #[tokio::test]
    async fn subscriber_receives_all_events() {
        let bus = EventBus::new(16);
        let mut rx = bus.subscribe();

        bus.publish("event-1".to_string());
        bus.publish("event-2".to_string());
        bus.publish("event-3".to_string());

        assert_eq!(rx.recv().await.unwrap(), "event-1");
        assert_eq!(rx.recv().await.unwrap(), "event-2");
        assert_eq!(rx.recv().await.unwrap(), "event-3");
    }

    #[tokio::test]
    async fn multiple_subscribers() {
        let bus = EventBus::new(16);
        let mut rx1 = bus.subscribe();
        let mut rx2 = bus.subscribe();

        bus.publish("shared".to_string());

        assert_eq!(rx1.recv().await.unwrap(), "shared");
        assert_eq!(rx2.recv().await.unwrap(), "shared");
    }

    #[test]
    fn publish_without_subscribers_does_not_panic() {
        let bus = EventBus::new(4);
        bus.publish("orphan".to_string());
    }

    #[test]
    fn event_bus_dropped_subscriber_does_not_block() {
        let bus = EventBus::new(16);
        let rx = bus.subscribe();
        drop(rx);
        bus.publish("should not block".to_string());
    }

    #[tokio::test]
    async fn bus_clone_shares_channel() {
        let bus1 = EventBus::new(16);
        let bus2 = bus1.clone();
        let mut rx = bus1.subscribe();

        bus2.publish("from-clone".to_string());
        assert_eq!(rx.recv().await.unwrap(), "from-clone");
    }

    #[tokio::test]
    async fn subscriber_after_publish_misses_earlier_events() {
        let bus = EventBus::new(16);
        bus.publish("before-subscribe".to_string());

        let mut rx = bus.subscribe();
        bus.publish("after-subscribe".to_string());

        assert_eq!(rx.recv().await.unwrap(), "after-subscribe");
    }

    #[test]
    fn capacity_overflow_does_not_panic() {
        let bus = EventBus::new(2);
        let _rx = bus.subscribe();
        for i in 0..10 {
            bus.publish(format!("event-{i}"));
        }
    }

    #[tokio::test]
    async fn publish_json_event_roundtrip() {
        let bus = EventBus::new(16);
        let mut rx = bus.subscribe();
        let event = serde_json::json!({"type": "inference", "model": "gpt-4", "tokens": 42});
        bus.publish(event.to_string());
        let msg = rx.recv().await.unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap();
        assert_eq!(parsed["type"], "inference");
        assert_eq!(parsed["tokens"], 42);
    }

    #[tokio::test]
    async fn multiple_publishes_order_preserved() {
        let bus = EventBus::new(64);
        let mut rx = bus.subscribe();
        for i in 0..50 {
            bus.publish(format!("msg-{i}"));
        }
        for i in 0..50 {
            assert_eq!(rx.recv().await.unwrap(), format!("msg-{i}"));
        }
    }

    #[tokio::test]
    async fn concurrent_publishers() {
        let bus = EventBus::new(256);
        let mut rx = bus.subscribe();
        let bus1 = bus.clone();
        let bus2 = bus.clone();

        let h1 = tokio::spawn(async move {
            for i in 0..10 {
                bus1.publish(format!("a-{i}"));
            }
        });
        let h2 = tokio::spawn(async move {
            for i in 0..10 {
                bus2.publish(format!("b-{i}"));
            }
        });
        h1.await.unwrap();
        h2.await.unwrap();

        // Both publishers have finished, so every event is already buffered: drain
        // synchronously instead of waiting for a timeout to signal "no more".
        let mut received = Vec::new();
        loop {
            match rx.try_recv() {
                Ok(msg) => received.push(msg),
                Err(broadcast::error::TryRecvError::Empty) => break,
                Err(e) => panic!("unexpected receive error: {e:?}"),
            }
        }
        let mut expected: Vec<String> = (0..10)
            .flat_map(|i| [format!("a-{i}"), format!("b-{i}")])
            .collect();
        received.sort();
        expected.sort();
        assert_eq!(received, expected);
    }

    // ---- Routing --------------------------------------------------------------

    #[test]
    fn ws_route_returns_method_router() {
        let _open = super::ws_route(EventBus::new(256), TicketStore::new(), None);
        let _keyed = super::ws_route(
            EventBus::new(256),
            TicketStore::new(),
            Some(TEST_KEY.to_string()),
        );
    }

    #[test]
    fn ws_route_builds_without_panic() {
        let router = axum::Router::new().route(
            "/ws",
            super::ws_route(EventBus::new(4), TicketStore::new(), None),
        );
        // The handler extracts `ConnectInfo<SocketAddr>`, so build the service the
        // way a real server must: with connect info attached.
        let _app = router.into_make_service_with_connect_info::<SocketAddr>();
    }

    // ---- Authentication ---------------------------------------------------------

    #[test]
    fn ws_auth_no_key_configured_allows_loopback_only() {
        let headers = axum::http::HeaderMap::new();
        let query = WsQuery { ticket: None };
        let tickets = TicketStore::new();
        let peer = |s: &str| Some(s.parse::<SocketAddr>().unwrap());

        assert!(ws_authenticate(&headers, &query, &tickets, None, peer("127.0.0.1:9000")));
        assert!(ws_authenticate(&headers, &query, &tickets, None, peer("[::1]:9000")));
        assert!(!ws_authenticate(&headers, &query, &tickets, None, peer("203.0.113.10:9000")));
        assert!(!ws_authenticate(&headers, &query, &tickets, None, None));

        // Without a configured key, headers cannot admit a remote peer.
        let keyed = headers_with("x-api-key", TEST_KEY);
        assert!(!ws_authenticate(&keyed, &query, &tickets, None, peer("203.0.113.10:9000")));
    }

    #[test]
    fn ws_auth_header_x_api_key() {
        let headers = headers_with("x-api-key", TEST_KEY);
        let query = WsQuery { ticket: None };
        assert!(auth_with_test_key(&headers, &query, &TicketStore::new()));
    }

    #[test]
    fn ws_auth_header_bearer() {
        let headers = headers_with("authorization", "Bearer test-key");
        let query = WsQuery { ticket: None };
        assert!(auth_with_test_key(&headers, &query, &TicketStore::new()));
    }

    #[test]
    fn ws_auth_bearer_accepted_when_x_api_key_wrong() {
        let mut headers = headers_with("x-api-key", "wrong-key");
        headers.insert("authorization", "Bearer test-key".parse().unwrap());
        let query = WsQuery { ticket: None };
        assert!(auth_with_test_key(&headers, &query, &TicketStore::new()));
    }

    #[test]
    fn ws_auth_bad_authorization_header_rejected() {
        let query = WsQuery { ticket: None };
        let tickets = TicketStore::new();
        for value in ["Bearer wrong-key", "Basic test-key", "test-key", "Bearer"] {
            let headers = headers_with("authorization", value);
            assert!(
                !auth_with_test_key(&headers, &query, &tickets),
                "accepted authorization header {value:?}"
            );
        }
    }

    #[test]
    fn ws_auth_valid_ticket() {
        let tickets = TicketStore::new();
        let query = WsQuery {
            ticket: Some(tickets.issue()),
        };
        assert!(auth_with_test_key(&axum::http::HeaderMap::new(), &query, &tickets));
    }

    #[test]
    fn ws_auth_invalid_ticket_rejected() {
        let query = WsQuery {
            ticket: Some("wst_invalid".to_string()),
        };
        assert!(!auth_with_test_key(
            &axum::http::HeaderMap::new(),
            &query,
            &TicketStore::new(),
        ));
    }

    #[test]
    fn ws_auth_no_credentials_rejected() {
        let query = WsQuery { ticket: None };
        assert!(!auth_with_test_key(
            &axum::http::HeaderMap::new(),
            &query,
            &TicketStore::new(),
        ));
    }

    #[test]
    fn ws_auth_wrong_key_rejected() {
        let headers = headers_with("x-api-key", "wrong-key");
        let query = WsQuery { ticket: None };
        assert!(!auth_with_test_key(&headers, &query, &TicketStore::new()));
    }

    #[test]
    fn ws_auth_ticket_single_use() {
        let headers = axum::http::HeaderMap::new();
        let tickets = TicketStore::new();
        let query = WsQuery {
            ticket: Some(tickets.issue()),
        };
        assert!(auth_with_test_key(&headers, &query, &tickets));
        assert!(!auth_with_test_key(&headers, &query, &tickets));
    }

    #[test]
    fn ws_auth_header_success_does_not_consume_ticket() {
        let tickets = TicketStore::new();
        let query = WsQuery {
            ticket: Some(tickets.issue()),
        };
        assert!(auth_with_test_key(&headers_with("x-api-key", TEST_KEY), &query, &tickets));
        // Header auth short-circuits before the ticket is redeemed, so it still works.
        assert!(auth_with_test_key(&axum::http::HeaderMap::new(), &query, &tickets));
    }
}