Skip to main content

roboticus_api/
ws.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
6use axum::http::StatusCode;
7use axum::response::IntoResponse;
8use serde::Deserialize;
9use subtle::ConstantTimeEq;
10use tokio::sync::broadcast;
11use tokio::time::{Instant, interval};
12
13use crate::ws_ticket::TicketStore;
14
15#[derive(Clone)]
16pub struct EventBus {
17    tx: broadcast::Sender<String>,
18}
19
20impl EventBus {
21    pub fn new(capacity: usize) -> Self {
22        let (tx, _) = broadcast::channel(capacity);
23        Self { tx }
24    }
25
26    pub fn publish(&self, event: String) {
27        if let Err(e) = self.tx.send(event) {
28            tracing::debug!(error = %e, "EventBus publish: no active subscribers");
29        }
30    }
31
32    pub fn subscribe(&self) -> broadcast::Receiver<String> {
33        self.tx.subscribe()
34    }
35}
36
/// Query-string parameters read from the WebSocket upgrade request.
#[derive(Deserialize)]
struct WsQuery {
    /// Value of the `ticket` query parameter, if the client supplied one.
    ticket: Option<String>,
}
41
42/// Returns an axum GET route handler that upgrades the connection to WebSocket.
43///
44/// Authentication is handled inside this handler (not by the global API-key
45/// middleware) because the `/ws` route lives outside the authed router group.
46/// Accepts either:
47///   - `x-api-key` / `Authorization: Bearer …` header (programmatic clients)
48///   - `?ticket=wst_…` query param (short-lived, single-use ticket from `POST /api/ws-ticket`)
49pub fn ws_route(
50    bus: EventBus,
51    tickets: TicketStore,
52    api_key: Option<String>,
53) -> axum::routing::MethodRouter {
54    let api_key: Option<Arc<str>> = api_key.map(|k| Arc::from(k.as_str()));
55
56    let handler =
57        move |ws: WebSocketUpgrade,
58              headers: axum::http::HeaderMap,
59              axum::extract::ConnectInfo(peer_addr): axum::extract::ConnectInfo<SocketAddr>,
60              axum::extract::Query(query): axum::extract::Query<WsQuery>| {
61            let bus = bus.clone();
62            let tickets = tickets.clone();
63            let api_key = api_key.clone();
64            async move {
65                if !ws_authenticate(
66                    &headers,
67                    &query,
68                    &tickets,
69                    api_key.as_deref(),
70                    Some(peer_addr),
71                ) {
72                    return (StatusCode::UNAUTHORIZED, "Valid API key or ticket required")
73                        .into_response();
74                }
75                ws.on_upgrade(move |socket| handle_socket(socket, bus))
76                    .into_response()
77            }
78        };
79    axum::routing::get(handler)
80}
81
82/// Check WebSocket auth: header first, then ticket, then reject.
83fn ws_authenticate(
84    headers: &axum::http::HeaderMap,
85    query: &WsQuery,
86    tickets: &TicketStore,
87    api_key: Option<&str>,
88    peer_addr: Option<SocketAddr>,
89) -> bool {
90    // If no API key is configured, mirror HTTP auth middleware behavior:
91    // allow loopback-only access, reject remote clients.
92    let Some(expected) = api_key else {
93        return peer_addr.is_some_and(|addr| addr.ip().is_loopback());
94    };
95
96    // 1. Check x-api-key header
97    if let Some(val) = headers.get("x-api-key")
98        && let Ok(provided) = val.to_str()
99        && bool::from(provided.as_bytes().ct_eq(expected.as_bytes()))
100    {
101        return true;
102    }
103
104    // 2. Check Authorization: Bearer header
105    if let Some(val) = headers.get("authorization")
106        && let Ok(s) = val.to_str()
107        && let Some(token) = s.strip_prefix("Bearer ")
108        && bool::from(token.as_bytes().ct_eq(expected.as_bytes()))
109    {
110        return true;
111    }
112
113    // 3. Check ticket query param (single-use, short-lived)
114    if let Some(ref ticket) = query.ticket
115        && tickets.redeem(ticket)
116    {
117        return true;
118    }
119
120    false
121}
122
/// Period of the timer on which `handle_socket` sends a WebSocket ping and
/// re-checks the idle deadline.
const PING_INTERVAL: Duration = Duration::from_secs(30);
/// A connection whose last recorded activity is older than this is closed on
/// the next ping tick.
const IDLE_TIMEOUT: Duration = Duration::from_secs(90);
125
126async fn handle_socket(mut socket: WebSocket, bus: EventBus) {
127    let mut rx = bus.subscribe();
128
129    // Send a welcome message
130    let welcome = serde_json::json!({
131        "type": "connected",
132        "version": env!("CARGO_PKG_VERSION"),
133        "timestamp": chrono::Utc::now().to_rfc3339(),
134    });
135    if let Err(e) = socket.send(Message::Text(welcome.to_string().into())).await {
136        tracing::debug!(error = %e, "WebSocket welcome send failed");
137        return;
138    }
139
140    let mut ping_timer = interval(PING_INTERVAL);
141    ping_timer.tick().await; // consume the immediate first tick
142    let mut last_activity = Instant::now();
143
144    // Forward events from the bus to the WebSocket client
145    loop {
146        tokio::select! {
147            msg = rx.recv() => {
148                match msg {
149                    Ok(event) => {
150                        if socket.send(Message::Text(event.into())).await.is_err() {
151                            break; // client disconnected
152                        }
153                        last_activity = Instant::now();
154                    }
155                    Err(broadcast::error::RecvError::Lagged(n)) => {
156                        tracing::warn!(skipped = n, "WebSocket subscriber lagged, skipping lost events");
157                        continue;
158                    }
159                    Err(broadcast::error::RecvError::Closed) => break,
160                }
161            }
162            msg = socket.recv() => {
163                match msg {
164                    Some(Ok(Message::Text(text))) => {
165                        last_activity = Instant::now();
166                        // Limit inbound message size to prevent memory amplification
167                        if text.len() > 4096 {
168                            tracing::warn!(len = text.len(), "WebSocket message exceeds 4KiB limit, closing");
169                            break;
170                        }
171                        let resp = serde_json::json!({ "type": "ack" });
172                        if let Err(e) = socket.send(Message::Text(resp.to_string().into())).await {
173                            tracing::debug!(error = %e, "WebSocket ack send failed");
174                            break;
175                        }
176                    }
177                    Some(Ok(Message::Ping(data))) => {
178                        last_activity = Instant::now();
179                        if let Err(e) = socket.send(Message::Pong(data)).await {
180                            tracing::debug!(error = %e, "WebSocket pong send failed");
181                            break;
182                        }
183                    }
184                    Some(Ok(Message::Pong(_))) => {
185                        last_activity = Instant::now();
186                    }
187                    Some(Ok(Message::Close(_))) | None => break,
188                    _ => {}
189                }
190            }
191            _ = ping_timer.tick() => {
192                if last_activity.elapsed() > IDLE_TIMEOUT {
193                    tracing::info!("WebSocket idle timeout, closing connection");
194                    let _ = socket.send(Message::Close(None)).await;
195                    break;
196                }
197                if let Err(e) = socket.send(Message::Ping(vec![].into())).await {
198                    tracing::debug!(error = %e, "WebSocket ping send failed");
199                    break;
200                }
201            }
202        }
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// API key configured in every authentication test that uses one.
    const KEY: &str = "test-key";

    /// Builds a header map holding the single entry `name: value`.
    fn headers_with(name: &'static str, value: &'static str) -> axum::http::HeaderMap {
        let mut map = axum::http::HeaderMap::new();
        map.insert(name, value.parse().unwrap());
        map
    }

    /// Runs `ws_authenticate` with `KEY` configured and no known peer address.
    fn auth_with_key(
        headers: &axum::http::HeaderMap,
        ticket: Option<String>,
        tickets: &TicketStore,
    ) -> bool {
        let query = WsQuery { ticket };
        ws_authenticate(headers, &query, tickets, Some(KEY), None)
    }

    // ── EventBus tests ────────────────────────────────────────────

    #[tokio::test]
    async fn publish_and_receive() {
        let bus = EventBus::new(16);
        let mut rx = bus.subscribe();

        bus.publish(String::from("hello"));
        assert_eq!(rx.recv().await.unwrap(), "hello");
    }

    #[tokio::test]
    async fn subscriber_receives_all_events() {
        let bus = EventBus::new(16);
        let mut rx = bus.subscribe();
        let expected = ["event-1", "event-2", "event-3"];

        // Publish everything first, then drain, so ordering is what is checked.
        for event in expected {
            bus.publish(event.to_string());
        }
        for event in expected {
            assert_eq!(rx.recv().await.unwrap(), event);
        }
    }

    #[tokio::test]
    async fn multiple_subscribers() {
        let bus = EventBus::new(16);
        let mut first = bus.subscribe();
        let mut second = bus.subscribe();

        bus.publish(String::from("shared"));

        assert_eq!(first.recv().await.unwrap(), "shared");
        assert_eq!(second.recv().await.unwrap(), "shared");
    }

    #[test]
    fn publish_without_subscribers_does_not_panic() {
        EventBus::new(4).publish(String::from("orphan"));
    }

    #[test]
    fn ws_route_returns_method_router() {
        let _router = ws_route(EventBus::new(256), TicketStore::new(), None);
    }

    #[tokio::test]
    async fn event_bus_publish_subscribe() {
        let bus = EventBus::new(16);
        let mut rx = bus.subscribe();
        bus.publish(String::from("hello"));
        let received = rx.recv().await.unwrap();
        assert_eq!(received, "hello");
    }

    #[tokio::test]
    async fn event_bus_multiple_subscribers() {
        let bus = EventBus::new(16);
        let mut first = bus.subscribe();
        let mut second = bus.subscribe();
        bus.publish(String::from("event1"));
        assert_eq!(first.recv().await.unwrap(), "event1");
        assert_eq!(second.recv().await.unwrap(), "event1");
    }

    #[test]
    fn event_bus_dropped_subscriber_does_not_block() {
        let bus = EventBus::new(16);
        // Subscribe and immediately discard the receiver.
        drop(bus.subscribe());
        bus.publish(String::from("should not block"));
    }

    #[tokio::test]
    async fn bus_clone_shares_channel() {
        let original = EventBus::new(16);
        let cloned = original.clone();
        let mut rx = original.subscribe();

        cloned.publish(String::from("from-clone"));
        assert_eq!(rx.recv().await.unwrap(), "from-clone");
    }

    #[tokio::test]
    async fn subscriber_after_publish_misses_earlier_events() {
        let bus = EventBus::new(16);
        bus.publish(String::from("before-subscribe"));

        let mut rx = bus.subscribe();
        bus.publish(String::from("after-subscribe"));

        // Only the event published after subscribing is observed.
        assert_eq!(rx.recv().await.unwrap(), "after-subscribe");
    }

    #[test]
    fn capacity_overflow_does_not_panic() {
        let bus = EventBus::new(2);
        let _rx = bus.subscribe();
        (0..10).for_each(|i| bus.publish(format!("event-{i}")));
    }

    #[tokio::test]
    async fn publish_json_event_roundtrip() {
        let bus = EventBus::new(16);
        let mut rx = bus.subscribe();
        let payload = serde_json::json!({"type": "inference", "model": "gpt-4", "tokens": 42});
        bus.publish(payload.to_string());

        let raw = rx.recv().await.unwrap();
        let decoded: serde_json::Value = serde_json::from_str(&raw).unwrap();
        assert_eq!(decoded["type"], "inference");
        assert_eq!(decoded["tokens"], 42);
    }

    #[tokio::test]
    async fn multiple_publishes_order_preserved() {
        let bus = EventBus::new(64);
        let mut rx = bus.subscribe();
        for i in 0..50 {
            bus.publish(format!("msg-{i}"));
        }
        for i in 0..50 {
            assert_eq!(rx.recv().await.unwrap(), format!("msg-{i}"));
        }
    }

    #[tokio::test]
    async fn concurrent_publishers() {
        let bus = EventBus::new(256);
        let mut rx = bus.subscribe();

        // Each publisher task owns its own clone of the bus.
        let spawn_publisher = |prefix: &'static str| {
            let publisher = bus.clone();
            tokio::spawn(async move {
                for i in 0..10 {
                    publisher.publish(format!("{prefix}-{i}"));
                }
            })
        };
        let first = spawn_publisher("a");
        let second = spawn_publisher("b");

        first.await.unwrap();
        second.await.unwrap();

        // Drain until the channel stays quiet for 100 ms.
        let mut received = 0;
        while let Ok(result) = tokio::time::timeout(Duration::from_millis(100), rx.recv()).await {
            result.unwrap();
            received += 1;
        }
        assert_eq!(received, 20);
    }

    #[test]
    fn ws_route_builds_without_panic() {
        let route = ws_route(EventBus::new(4), TicketStore::new(), None);
        let app = axum::Router::new().route("/ws", route);
        let _service = app.into_make_service();
    }

    // ── WebSocket authentication tests ────────────────────────────

    #[test]
    fn ws_auth_no_key_configured_allows_loopback_only() {
        let headers = axum::http::HeaderMap::new();
        let query = WsQuery { ticket: None };
        let tickets = TicketStore::new();
        let check = |peer: Option<&str>| {
            let peer = peer.map(|p| p.parse::<SocketAddr>().unwrap());
            ws_authenticate(&headers, &query, &tickets, None, peer)
        };

        assert!(check(Some("127.0.0.1:9000")));
        assert!(!check(Some("203.0.113.10:9000")));
        assert!(!check(None));
    }

    #[test]
    fn ws_auth_header_x_api_key() {
        let headers = headers_with("x-api-key", "test-key");
        assert!(auth_with_key(&headers, None, &TicketStore::new()));
    }

    #[test]
    fn ws_auth_header_bearer() {
        let headers = headers_with("authorization", "Bearer test-key");
        assert!(auth_with_key(&headers, None, &TicketStore::new()));
    }

    #[test]
    fn ws_auth_valid_ticket() {
        let headers = axum::http::HeaderMap::new();
        let tickets = TicketStore::new();
        let ticket = tickets.issue();
        assert!(auth_with_key(&headers, Some(ticket), &tickets));
    }

    #[test]
    fn ws_auth_invalid_ticket_rejected() {
        let headers = axum::http::HeaderMap::new();
        let tickets = TicketStore::new();
        let bogus = Some("wst_invalid".to_string());
        assert!(!auth_with_key(&headers, bogus, &tickets));
    }

    #[test]
    fn ws_auth_no_credentials_rejected() {
        let headers = axum::http::HeaderMap::new();
        assert!(!auth_with_key(&headers, None, &TicketStore::new()));
    }

    #[test]
    fn ws_auth_wrong_key_rejected() {
        let headers = headers_with("x-api-key", "wrong-key");
        assert!(!auth_with_key(&headers, None, &TicketStore::new()));
    }

    #[test]
    fn ws_auth_ticket_single_use() {
        let headers = axum::http::HeaderMap::new();
        let tickets = TicketStore::new();
        let ticket = tickets.issue();

        // First redemption succeeds, replaying the same ticket does not.
        assert!(auth_with_key(&headers, Some(ticket.clone()), &tickets));
        assert!(!auth_with_key(&headers, Some(ticket), &tickets));
    }
}