1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
6use axum::http::StatusCode;
7use axum::response::IntoResponse;
8use serde::Deserialize;
9use subtle::ConstantTimeEq;
10use tokio::sync::broadcast;
11use tokio::time::{Instant, interval};
12
13use crate::ws_ticket::TicketStore;
14
15#[derive(Clone)]
16pub struct EventBus {
17 tx: broadcast::Sender<String>,
18}
19
20impl EventBus {
21 pub fn new(capacity: usize) -> Self {
22 let (tx, _) = broadcast::channel(capacity);
23 Self { tx }
24 }
25
26 pub fn publish(&self, event: String) {
27 if let Err(e) = self.tx.send(event) {
28 tracing::debug!(error = %e, "EventBus publish: no active subscribers");
29 }
30 }
31
32 pub fn subscribe(&self) -> broadcast::Receiver<String> {
33 self.tx.subscribe()
34 }
35}
36
37#[derive(Deserialize)]
38struct WsQuery {
39 ticket: Option<String>,
40}
41
42pub fn ws_route(
50 bus: EventBus,
51 tickets: TicketStore,
52 api_key: Option<String>,
53) -> axum::routing::MethodRouter {
54 let api_key: Option<Arc<str>> = api_key.map(|k| Arc::from(k.as_str()));
55
56 let handler =
57 move |ws: WebSocketUpgrade,
58 headers: axum::http::HeaderMap,
59 axum::extract::ConnectInfo(peer_addr): axum::extract::ConnectInfo<SocketAddr>,
60 axum::extract::Query(query): axum::extract::Query<WsQuery>| {
61 let bus = bus.clone();
62 let tickets = tickets.clone();
63 let api_key = api_key.clone();
64 async move {
65 if !ws_authenticate(
66 &headers,
67 &query,
68 &tickets,
69 api_key.as_deref(),
70 Some(peer_addr),
71 ) {
72 return (StatusCode::UNAUTHORIZED, "Valid API key or ticket required")
73 .into_response();
74 }
75 ws.on_upgrade(move |socket| handle_socket(socket, bus))
76 .into_response()
77 }
78 };
79 axum::routing::get(handler)
80}
81
82fn ws_authenticate(
84 headers: &axum::http::HeaderMap,
85 query: &WsQuery,
86 tickets: &TicketStore,
87 api_key: Option<&str>,
88 peer_addr: Option<SocketAddr>,
89) -> bool {
90 let Some(expected) = api_key else {
93 return peer_addr.is_some_and(|addr| addr.ip().is_loopback());
94 };
95
96 if let Some(val) = headers.get("x-api-key")
98 && let Ok(provided) = val.to_str()
99 && bool::from(provided.as_bytes().ct_eq(expected.as_bytes()))
100 {
101 return true;
102 }
103
104 if let Some(val) = headers.get("authorization")
106 && let Ok(s) = val.to_str()
107 && let Some(token) = s.strip_prefix("Bearer ")
108 && bool::from(token.as_bytes().ct_eq(expected.as_bytes()))
109 {
110 return true;
111 }
112
113 if let Some(ref ticket) = query.ticket
115 && tickets.redeem(ticket)
116 {
117 return true;
118 }
119
120 false
121}
122
/// How often the server sends a WebSocket `Ping` to each connected client.
const PING_INTERVAL: Duration = Duration::from_millis(30_000);
/// How long a connection may go without recorded activity before the server
/// closes it. This is only checked when the ping timer fires, so the actual
/// cutoff lands roughly between `IDLE_TIMEOUT` and `IDLE_TIMEOUT + PING_INTERVAL`.
const IDLE_TIMEOUT: Duration = Duration::from_millis(90_000);
125
126async fn handle_socket(mut socket: WebSocket, bus: EventBus) {
127 let mut rx = bus.subscribe();
128
129 let welcome = serde_json::json!({
131 "type": "connected",
132 "version": env!("CARGO_PKG_VERSION"),
133 "timestamp": chrono::Utc::now().to_rfc3339(),
134 });
135 if let Err(e) = socket.send(Message::Text(welcome.to_string().into())).await {
136 tracing::debug!(error = %e, "WebSocket welcome send failed");
137 return;
138 }
139
140 let mut ping_timer = interval(PING_INTERVAL);
141 ping_timer.tick().await; let mut last_activity = Instant::now();
143
144 loop {
146 tokio::select! {
147 msg = rx.recv() => {
148 match msg {
149 Ok(event) => {
150 if socket.send(Message::Text(event.into())).await.is_err() {
151 break; }
153 last_activity = Instant::now();
154 }
155 Err(broadcast::error::RecvError::Lagged(n)) => {
156 tracing::warn!(skipped = n, "WebSocket subscriber lagged, skipping lost events");
157 continue;
158 }
159 Err(broadcast::error::RecvError::Closed) => break,
160 }
161 }
162 msg = socket.recv() => {
163 match msg {
164 Some(Ok(Message::Text(text))) => {
165 last_activity = Instant::now();
166 if text.len() > 4096 {
168 tracing::warn!(len = text.len(), "WebSocket message exceeds 4KiB limit, closing");
169 break;
170 }
171 let resp = serde_json::json!({ "type": "ack" });
172 if let Err(e) = socket.send(Message::Text(resp.to_string().into())).await {
173 tracing::debug!(error = %e, "WebSocket ack send failed");
174 break;
175 }
176 }
177 Some(Ok(Message::Ping(data))) => {
178 last_activity = Instant::now();
179 if let Err(e) = socket.send(Message::Pong(data)).await {
180 tracing::debug!(error = %e, "WebSocket pong send failed");
181 break;
182 }
183 }
184 Some(Ok(Message::Pong(_))) => {
185 last_activity = Instant::now();
186 }
187 Some(Ok(Message::Close(_))) | None => break,
188 _ => {}
189 }
190 }
191 _ = ping_timer.tick() => {
192 if last_activity.elapsed() > IDLE_TIMEOUT {
193 tracing::info!("WebSocket idle timeout, closing connection");
194 let _ = socket.send(Message::Close(None)).await;
195 break;
196 }
197 if let Err(e) = socket.send(Message::Ping(vec![].into())).await {
198 tracing::debug!(error = %e, "WebSocket ping send failed");
199 break;
200 }
201 }
202 }
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `HeaderMap` from `(name, value)` pairs; panics on invalid values.
    fn headers_with(pairs: &[(&'static str, &str)]) -> axum::http::HeaderMap {
        let mut headers = axum::http::HeaderMap::new();
        for &(name, value) in pairs {
            headers.insert(name, value.parse().unwrap());
        }
        headers
    }

    /// Parses a socket-address literal into the `Option` form `ws_authenticate` takes.
    fn peer(addr: &str) -> Option<SocketAddr> {
        Some(addr.parse().unwrap())
    }

    // ---- EventBus ----

    #[tokio::test]
    async fn publish_and_receive() {
        let bus = EventBus::new(16);
        let mut rx = bus.subscribe();

        bus.publish("hello".to_string());
        assert_eq!(rx.recv().await.unwrap(), "hello");
    }

    #[tokio::test]
    async fn subscriber_receives_all_events() {
        let bus = EventBus::new(16);
        let mut rx = bus.subscribe();

        bus.publish("event-1".to_string());
        bus.publish("event-2".to_string());
        bus.publish("event-3".to_string());

        assert_eq!(rx.recv().await.unwrap(), "event-1");
        assert_eq!(rx.recv().await.unwrap(), "event-2");
        assert_eq!(rx.recv().await.unwrap(), "event-3");
    }

    #[tokio::test]
    async fn multiple_subscribers() {
        let bus = EventBus::new(16);
        let mut rx1 = bus.subscribe();
        let mut rx2 = bus.subscribe();

        bus.publish("shared".to_string());

        assert_eq!(rx1.recv().await.unwrap(), "shared");
        assert_eq!(rx2.recv().await.unwrap(), "shared");
    }

    #[test]
    fn publish_without_subscribers_does_not_panic() {
        let bus = EventBus::new(4);
        bus.publish("orphan".to_string());
    }

    #[test]
    fn event_bus_dropped_subscriber_does_not_block() {
        let bus = EventBus::new(16);
        drop(bus.subscribe());
        bus.publish("should not block".to_string());
    }

    #[tokio::test]
    async fn bus_clone_shares_channel() {
        let bus1 = EventBus::new(16);
        let bus2 = bus1.clone();
        let mut rx = bus1.subscribe();

        bus2.publish("from-clone".to_string());
        assert_eq!(rx.recv().await.unwrap(), "from-clone");
    }

    #[tokio::test]
    async fn subscriber_after_publish_misses_earlier_events() {
        let bus = EventBus::new(16);
        bus.publish("before-subscribe".to_string());

        let mut rx = bus.subscribe();
        bus.publish("after-subscribe".to_string());

        assert_eq!(rx.recv().await.unwrap(), "after-subscribe");
    }

    #[test]
    fn capacity_overflow_does_not_panic() {
        let bus = EventBus::new(2);
        let mut rx = bus.subscribe();
        for i in 0..10 {
            bus.publish(format!("event-{i}"));
        }
        // Only the newest `capacity` events survive; the receiver is told how
        // many it missed and then resumes with the oldest retained event.
        assert!(matches!(
            rx.try_recv(),
            Err(broadcast::error::TryRecvError::Lagged(8))
        ));
        assert_eq!(rx.try_recv().unwrap(), "event-8");
        assert_eq!(rx.try_recv().unwrap(), "event-9");
    }

    #[tokio::test]
    async fn publish_json_event_roundtrip() {
        let bus = EventBus::new(16);
        let mut rx = bus.subscribe();
        let event = serde_json::json!({"type": "inference", "model": "gpt-4", "tokens": 42});
        bus.publish(event.to_string());

        let msg = rx.recv().await.unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap();
        assert_eq!(parsed["type"], "inference");
        assert_eq!(parsed["tokens"], 42);
    }

    #[tokio::test]
    async fn multiple_publishes_order_preserved() {
        let bus = EventBus::new(64);
        let mut rx = bus.subscribe();
        for i in 0..50 {
            bus.publish(format!("msg-{i}"));
        }
        for i in 0..50 {
            assert_eq!(rx.recv().await.unwrap(), format!("msg-{i}"));
        }
    }

    #[tokio::test]
    async fn concurrent_publishers() {
        let bus = EventBus::new(256);
        let mut rx = bus.subscribe();

        let handles: Vec<_> = ["a", "b"]
            .into_iter()
            .map(|prefix| {
                let bus = bus.clone();
                tokio::spawn(async move {
                    for i in 0..10 {
                        bus.publish(format!("{prefix}-{i}"));
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }

        // Both publishers have finished, so every event is already buffered:
        // drain without timing assumptions.
        let mut received = Vec::new();
        while let Ok(msg) = rx.try_recv() {
            received.push(msg);
        }
        assert_eq!(received.len(), 20);

        // Interleaving between publishers is arbitrary, but each one's own
        // events must arrive in the order it sent them.
        for prefix in ["a", "b"] {
            let tag = format!("{prefix}-");
            let own: Vec<&String> = received.iter().filter(|m| m.starts_with(&tag)).collect();
            let expected: Vec<String> = (0..10).map(|i| format!("{prefix}-{i}")).collect();
            assert_eq!(own, expected.iter().collect::<Vec<_>>());
        }
    }

    // ---- routing ----

    #[test]
    fn ws_route_builds_without_panic() {
        for api_key in [None, Some("test-key".to_string())] {
            let router = axum::Router::new().route(
                "/ws",
                super::ws_route(EventBus::new(4), TicketStore::new(), api_key),
            );
            // The handler extracts `ConnectInfo<SocketAddr>`, so serve it this way.
            let _app = router.into_make_service_with_connect_info::<SocketAddr>();
        }
    }

    // ---- authentication ----

    #[test]
    fn ws_auth_no_key_configured_allows_loopback_only() {
        let headers = axum::http::HeaderMap::new();
        let query = WsQuery { ticket: None };
        let tickets = TicketStore::new();
        let auth = |addr| ws_authenticate(&headers, &query, &tickets, None, addr);

        assert!(auth(peer("127.0.0.1:9000")));
        assert!(auth(peer("[::1]:9000")));
        assert!(!auth(peer("203.0.113.10:9000")));
        assert!(!auth(peer("[2001:db8::1]:9000")));
        assert!(!auth(None));
    }

    #[test]
    fn ws_auth_configured_key_does_not_trust_loopback() {
        let headers = axum::http::HeaderMap::new();
        let query = WsQuery { ticket: None };
        let tickets = TicketStore::new();
        assert!(!ws_authenticate(
            &headers,
            &query,
            &tickets,
            Some("test-key"),
            peer("127.0.0.1:9000"),
        ));
    }

    #[test]
    fn ws_auth_header_x_api_key() {
        let headers = headers_with(&[("x-api-key", "test-key")]);
        let query = WsQuery { ticket: None };
        let tickets = TicketStore::new();
        assert!(ws_authenticate(&headers, &query, &tickets, Some("test-key"), None));
    }

    #[test]
    fn ws_auth_header_bearer() {
        let headers = headers_with(&[("authorization", "Bearer test-key")]);
        let query = WsQuery { ticket: None };
        let tickets = TicketStore::new();
        assert!(ws_authenticate(&headers, &query, &tickets, Some("test-key"), None));
    }

    #[test]
    fn ws_auth_bearer_wrong_key_rejected() {
        let headers = headers_with(&[("authorization", "Bearer wrong-key")]);
        let query = WsQuery { ticket: None };
        let tickets = TicketStore::new();
        assert!(!ws_authenticate(&headers, &query, &tickets, Some("test-key"), None));
    }

    #[test]
    fn ws_auth_authorization_without_bearer_scheme_rejected() {
        let query = WsQuery { ticket: None };
        let tickets = TicketStore::new();
        for value in ["test-key", "Basic test-key"] {
            let headers = headers_with(&[("authorization", value)]);
            assert!(
                !ws_authenticate(&headers, &query, &tickets, Some("test-key"), None),
                "accepted authorization value {value:?}"
            );
        }
    }

    #[test]
    fn ws_auth_valid_ticket() {
        let headers = axum::http::HeaderMap::new();
        let tickets = TicketStore::new();
        let query = WsQuery {
            ticket: Some(tickets.issue()),
        };
        assert!(ws_authenticate(&headers, &query, &tickets, Some("test-key"), None));
    }

    #[test]
    fn ws_auth_invalid_ticket_rejected() {
        let headers = axum::http::HeaderMap::new();
        let tickets = TicketStore::new();
        let query = WsQuery {
            ticket: Some("wst_invalid".to_string()),
        };
        assert!(!ws_authenticate(&headers, &query, &tickets, Some("test-key"), None));
    }

    #[test]
    fn ws_auth_no_credentials_rejected() {
        let headers = axum::http::HeaderMap::new();
        let query = WsQuery { ticket: None };
        let tickets = TicketStore::new();
        assert!(!ws_authenticate(&headers, &query, &tickets, Some("test-key"), None));
    }

    #[test]
    fn ws_auth_wrong_key_rejected() {
        let headers = headers_with(&[("x-api-key", "wrong-key")]);
        let query = WsQuery { ticket: None };
        let tickets = TicketStore::new();
        assert!(!ws_authenticate(&headers, &query, &tickets, Some("test-key"), None));
    }

    #[test]
    fn ws_auth_ticket_single_use() {
        let headers = axum::http::HeaderMap::new();
        let tickets = TicketStore::new();
        let ticket = tickets.issue();

        let first = WsQuery {
            ticket: Some(ticket.clone()),
        };
        assert!(ws_authenticate(&headers, &first, &tickets, Some("test-key"), None));

        let second = WsQuery {
            ticket: Some(ticket),
        };
        assert!(!ws_authenticate(&headers, &second, &tickets, Some("test-key"), None));
    }

    #[test]
    fn ws_auth_header_match_leaves_ticket_unredeemed() {
        let headers = headers_with(&[("x-api-key", "test-key")]);
        let tickets = TicketStore::new();
        let ticket = tickets.issue();
        let query = WsQuery {
            ticket: Some(ticket.clone()),
        };
        assert!(ws_authenticate(&headers, &query, &tickets, Some("test-key"), None));
        // The header already authorised the request, so the ticket is still usable.
        assert!(tickets.redeem(&ticket));
    }
}