Skip to main content

roboticus_api/
ws.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
6use axum::http::StatusCode;
7use axum::response::IntoResponse;
8use serde::Deserialize;
9use subtle::ConstantTimeEq;
10use tokio::sync::broadcast;
11use tokio::time::{Instant, interval};
12
13use crate::ws_ticket::TicketStore;
14
15pub use roboticus_pipeline::event_bus::EventBus;
16
/// Query-string parameters accepted by the WebSocket upgrade route.
#[derive(Deserialize)]
struct WsQuery {
    /// Optional `?ticket=…` value. When present, `ws_authenticate` tries to
    /// redeem it against the `TicketStore` after the header checks have failed.
    ticket: Option<String>,
}
21
22/// Returns an axum GET route handler that upgrades the connection to WebSocket.
23///
24/// Authentication is handled inside this handler (not by the global API-key
25/// middleware) because the `/ws` route lives outside the authed router group.
26/// Accepts either:
27///   - `x-api-key` / `Authorization: Bearer …` header (programmatic clients)
28///   - `?ticket=wst_…` query param (short-lived, single-use ticket from `POST /api/ws-ticket`)
29pub fn ws_route(
30    bus: EventBus,
31    tickets: TicketStore,
32    api_key: Option<String>,
33) -> axum::routing::MethodRouter {
34    let api_key: Option<Arc<str>> = api_key.map(|k| Arc::from(k.as_str()));
35
36    let handler =
37        move |ws: WebSocketUpgrade,
38              headers: axum::http::HeaderMap,
39              axum::extract::ConnectInfo(peer_addr): axum::extract::ConnectInfo<SocketAddr>,
40              axum::extract::Query(query): axum::extract::Query<WsQuery>| {
41            let bus = bus.clone();
42            let tickets = tickets.clone();
43            let api_key = api_key.clone();
44            async move {
45                if !ws_authenticate(
46                    &headers,
47                    &query,
48                    &tickets,
49                    api_key.as_deref(),
50                    Some(peer_addr),
51                ) {
52                    return (StatusCode::UNAUTHORIZED, "Valid API key or ticket required")
53                        .into_response();
54                }
55                ws.on_upgrade(move |socket| handle_socket(socket, bus))
56                    .into_response()
57            }
58        };
59    axum::routing::get(handler)
60}
61
62/// Check WebSocket auth: header first, then ticket, then reject.
63fn ws_authenticate(
64    headers: &axum::http::HeaderMap,
65    query: &WsQuery,
66    tickets: &TicketStore,
67    api_key: Option<&str>,
68    peer_addr: Option<SocketAddr>,
69) -> bool {
70    // If no API key is configured, mirror HTTP auth middleware behavior:
71    // allow loopback-only access, reject remote clients.
72    let Some(expected) = api_key else {
73        return peer_addr.is_some_and(|addr| addr.ip().is_loopback());
74    };
75
76    // 1. Check x-api-key header
77    if let Some(val) = headers.get("x-api-key")
78        && let Ok(provided) = val.to_str()
79        && bool::from(provided.as_bytes().ct_eq(expected.as_bytes()))
80    {
81        return true;
82    }
83
84    // 2. Check Authorization: Bearer header
85    if let Some(val) = headers.get("authorization")
86        && let Ok(s) = val.to_str()
87        && let Some(token) = s.strip_prefix("Bearer ")
88        && bool::from(token.as_bytes().ct_eq(expected.as_bytes()))
89    {
90        return true;
91    }
92
93    // 3. Check ticket query param (single-use, short-lived)
94    if let Some(ref ticket) = query.ticket
95        && tickets.redeem(ticket)
96    {
97        return true;
98    }
99
100    false
101}
102
103const PING_INTERVAL: Duration = Duration::from_secs(30);
104const IDLE_TIMEOUT: Duration = Duration::from_secs(90);
105
106async fn handle_socket(mut socket: WebSocket, bus: EventBus) {
107    let mut rx = bus.subscribe();
108
109    // Send a welcome message
110    let welcome = serde_json::json!({
111        "type": "connected",
112        "version": env!("CARGO_PKG_VERSION"),
113        "timestamp": chrono::Utc::now().to_rfc3339(),
114    });
115    if let Err(e) = socket.send(Message::Text(welcome.to_string().into())).await {
116        tracing::debug!(error = %e, "WebSocket welcome send failed");
117        return;
118    }
119
120    let mut ping_timer = interval(PING_INTERVAL);
121    ping_timer.tick().await; // consume the immediate first tick
122    let mut last_activity = Instant::now();
123
124    // Forward events from the bus to the WebSocket client
125    loop {
126        tokio::select! {
127            msg = rx.recv() => {
128                match msg {
129                    Ok(event) => {
130                        if socket.send(Message::Text(event.into())).await.is_err() {
131                            break; // client disconnected
132                        }
133                        last_activity = Instant::now();
134                    }
135                    Err(broadcast::error::RecvError::Lagged(n)) => {
136                        tracing::warn!(skipped = n, "WebSocket subscriber lagged, skipping lost events");
137                        continue;
138                    }
139                    Err(broadcast::error::RecvError::Closed) => break,
140                }
141            }
142            msg = socket.recv() => {
143                match msg {
144                    Some(Ok(Message::Text(text))) => {
145                        last_activity = Instant::now();
146                        // Limit inbound message size to prevent memory amplification
147                        if text.len() > 4096 {
148                            tracing::warn!(len = text.len(), "WebSocket message exceeds 4KiB limit, closing");
149                            break;
150                        }
151                        let resp = serde_json::json!({ "type": "ack" });
152                        if let Err(e) = socket.send(Message::Text(resp.to_string().into())).await {
153                            tracing::debug!(error = %e, "WebSocket ack send failed");
154                            break;
155                        }
156                    }
157                    Some(Ok(Message::Ping(data))) => {
158                        last_activity = Instant::now();
159                        if let Err(e) = socket.send(Message::Pong(data)).await {
160                            tracing::debug!(error = %e, "WebSocket pong send failed");
161                            break;
162                        }
163                    }
164                    Some(Ok(Message::Pong(_))) => {
165                        last_activity = Instant::now();
166                    }
167                    Some(Ok(Message::Close(_))) | None => break,
168                    _ => {}
169                }
170            }
171            _ = ping_timer.tick() => {
172                if last_activity.elapsed() > IDLE_TIMEOUT {
173                    tracing::info!("WebSocket idle timeout, closing connection");
174                    let _ = socket.send(Message::Close(None)).await;
175                    break;
176                }
177                if let Err(e) = socket.send(Message::Ping(vec![].into())).await {
178                    tracing::debug!(error = %e, "WebSocket ping send failed");
179                    break;
180                }
181            }
182        }
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// API key configured for the authentication tests.
    const TEST_KEY: &str = "test-key";

    /// Builds a header map holding a single `name: value` pair.
    fn headers_with(name: &'static str, value: &'static str) -> axum::http::HeaderMap {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(name, value.parse().unwrap());
        headers
    }

    /// Parses `addr` into the `Option<SocketAddr>` shape `ws_authenticate` takes.
    fn peer(addr: &str) -> Option<SocketAddr> {
        Some(addr.parse().unwrap())
    }

    /// Runs `ws_authenticate` with `TEST_KEY` configured and no peer address.
    fn auth_with_key(
        headers: &axum::http::HeaderMap,
        ticket: Option<String>,
        tickets: &TicketStore,
    ) -> bool {
        ws_authenticate(headers, &WsQuery { ticket }, tickets, Some(TEST_KEY), None)
    }

    // ── EventBus behaviour relied on by the socket loop ───────────

    #[tokio::test]
    async fn publish_and_receive() {
        let bus = EventBus::new(16);
        let mut rx = bus.subscribe();

        bus.publish("hello".to_string());
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg, "hello");
    }

    #[tokio::test]
    async fn subscriber_receives_all_events() {
        let bus = EventBus::new(16);
        let mut rx = bus.subscribe();

        bus.publish("event-1".to_string());
        bus.publish("event-2".to_string());
        bus.publish("event-3".to_string());

        let m1 = rx.recv().await.unwrap();
        let m2 = rx.recv().await.unwrap();
        let m3 = rx.recv().await.unwrap();

        assert_eq!(m1, "event-1");
        assert_eq!(m2, "event-2");
        assert_eq!(m3, "event-3");
    }

    #[tokio::test]
    async fn multiple_subscribers() {
        let bus = EventBus::new(16);
        let mut rx1 = bus.subscribe();
        let mut rx2 = bus.subscribe();

        bus.publish("shared".to_string());

        assert_eq!(rx1.recv().await.unwrap(), "shared");
        assert_eq!(rx2.recv().await.unwrap(), "shared");
    }

    #[test]
    fn publish_without_subscribers_does_not_panic() {
        let bus = EventBus::new(4);
        bus.publish("orphan".to_string());
    }

    #[test]
    fn event_bus_dropped_subscriber_does_not_block() {
        let bus = EventBus::new(16);
        let rx = bus.subscribe();
        drop(rx);
        bus.publish("should not block".to_string());
    }

    #[tokio::test]
    async fn bus_clone_shares_channel() {
        let bus1 = EventBus::new(16);
        let bus2 = bus1.clone();
        let mut rx = bus1.subscribe();

        bus2.publish("from-clone".to_string());
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg, "from-clone");
    }

    #[tokio::test]
    async fn subscriber_after_publish_misses_earlier_events() {
        let bus = EventBus::new(16);
        bus.publish("before-subscribe".to_string());

        let mut rx = bus.subscribe();
        bus.publish("after-subscribe".to_string());

        let msg = rx.recv().await.unwrap();
        assert_eq!(msg, "after-subscribe");
    }

    #[test]
    fn capacity_overflow_does_not_panic() {
        let bus = EventBus::new(2);
        let _rx = bus.subscribe();
        for i in 0..10 {
            bus.publish(format!("event-{i}"));
        }
    }

    #[tokio::test]
    async fn publish_json_event_roundtrip() {
        let bus = EventBus::new(16);
        let mut rx = bus.subscribe();
        let event = serde_json::json!({"type": "inference", "model": "gpt-4", "tokens": 42});
        bus.publish(event.to_string());
        let msg = rx.recv().await.unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap();
        assert_eq!(parsed["type"], "inference");
        assert_eq!(parsed["tokens"], 42);
    }

    #[tokio::test]
    async fn multiple_publishes_order_preserved() {
        let bus = EventBus::new(64);
        let mut rx = bus.subscribe();
        for i in 0..50 {
            bus.publish(format!("msg-{i}"));
        }
        for i in 0..50 {
            let msg = rx.recv().await.unwrap();
            assert_eq!(msg, format!("msg-{i}"));
        }
    }

    #[tokio::test]
    async fn concurrent_publishers() {
        let bus = EventBus::new(256);
        let mut rx = bus.subscribe();
        let bus1 = bus.clone();
        let bus2 = bus.clone();

        let h1 = tokio::spawn(async move {
            for i in 0..10 {
                bus1.publish(format!("a-{i}"));
            }
        });
        let h2 = tokio::spawn(async move {
            for i in 0..10 {
                bus2.publish(format!("b-{i}"));
            }
        });

        h1.await.unwrap();
        h2.await.unwrap();

        // The bus is still alive, so `recv` never reports `Closed`; the timeout
        // is what ends the drain once all queued messages have been read.
        let mut count = 0;
        while let Ok(msg) =
            tokio::time::timeout(std::time::Duration::from_millis(100), rx.recv()).await
        {
            msg.unwrap();
            count += 1;
        }
        assert_eq!(count, 20);
    }

    // ── Route construction ────────────────────────────────────────

    #[test]
    fn ws_route_builds_without_panic() {
        let bus = EventBus::new(4);
        let tickets = TicketStore::new();
        let router = axum::Router::new().route("/ws", super::ws_route(bus, tickets, None));
        let _app = router.into_make_service();
    }

    // ── WebSocket authentication tests ────────────────────────────

    #[test]
    fn ws_auth_no_key_configured_allows_loopback_only() {
        let headers = axum::http::HeaderMap::new();
        let query = WsQuery { ticket: None };
        let tickets = TicketStore::new();
        let loopback = peer("127.0.0.1:9000");
        let remote = peer("203.0.113.10:9000");
        assert!(ws_authenticate(&headers, &query, &tickets, None, loopback));
        assert!(!ws_authenticate(&headers, &query, &tickets, None, remote));
        assert!(!ws_authenticate(&headers, &query, &tickets, None, None));
    }

    #[test]
    fn ws_auth_no_key_configured_allows_ipv6_loopback() {
        let headers = axum::http::HeaderMap::new();
        let query = WsQuery { ticket: None };
        let tickets = TicketStore::new();
        assert!(ws_authenticate(
            &headers,
            &query,
            &tickets,
            None,
            peer("[::1]:9000"),
        ));
    }

    #[test]
    fn ws_auth_no_key_configured_ignores_credentials_from_remote_peers() {
        // Without a configured key only the peer address matters: neither a
        // header nor a freshly issued ticket lets a remote client in.
        let tickets = TicketStore::new();
        let headers = headers_with("x-api-key", TEST_KEY);
        let query = WsQuery {
            ticket: Some(tickets.issue()),
        };
        assert!(!ws_authenticate(
            &headers,
            &query,
            &tickets,
            None,
            peer("203.0.113.10:9000"),
        ));
    }

    #[test]
    fn ws_auth_header_x_api_key() {
        let headers = headers_with("x-api-key", "test-key");
        let tickets = TicketStore::new();
        assert!(auth_with_key(&headers, None, &tickets));
    }

    #[test]
    fn ws_auth_header_bearer() {
        let headers = headers_with("authorization", "Bearer test-key");
        let tickets = TicketStore::new();
        assert!(auth_with_key(&headers, None, &tickets));
    }

    #[test]
    fn ws_auth_valid_ticket() {
        let headers = axum::http::HeaderMap::new();
        let tickets = TicketStore::new();
        let ticket = tickets.issue();
        assert!(auth_with_key(&headers, Some(ticket), &tickets));
    }

    #[test]
    fn ws_auth_invalid_ticket_rejected() {
        let headers = axum::http::HeaderMap::new();
        let tickets = TicketStore::new();
        assert!(!auth_with_key(
            &headers,
            Some("wst_invalid".to_string()),
            &tickets,
        ));
    }

    #[test]
    fn ws_auth_no_credentials_rejected() {
        let headers = axum::http::HeaderMap::new();
        let tickets = TicketStore::new();
        assert!(!auth_with_key(&headers, None, &tickets));
    }

    #[test]
    fn ws_auth_wrong_key_rejected() {
        let headers = headers_with("x-api-key", "wrong-key");
        let tickets = TicketStore::new();
        assert!(!auth_with_key(&headers, None, &tickets));
    }

    #[test]
    fn ws_auth_wrong_bearer_token_rejected() {
        let headers = headers_with("authorization", "Bearer wrong-key");
        let tickets = TicketStore::new();
        assert!(!auth_with_key(&headers, None, &tickets));
    }

    #[test]
    fn ws_auth_authorization_without_bearer_scheme_rejected() {
        // The raw key in `Authorization` without the `Bearer ` prefix is not accepted.
        let headers = headers_with("authorization", "test-key");
        let tickets = TicketStore::new();
        assert!(!auth_with_key(&headers, None, &tickets));
    }

    #[test]
    fn ws_auth_ticket_single_use() {
        let headers = axum::http::HeaderMap::new();
        let tickets = TicketStore::new();
        let ticket = tickets.issue();
        assert!(auth_with_key(&headers, Some(ticket.clone()), &tickets));
        assert!(!auth_with_key(&headers, Some(ticket), &tickets));
    }

    #[test]
    fn ws_auth_header_success_does_not_consume_ticket() {
        // Header checks run first, so a request that authenticates by header
        // must leave an accompanying ticket unredeemed.
        let tickets = TicketStore::new();
        let ticket = tickets.issue();
        let keyed = headers_with("x-api-key", "test-key");
        assert!(auth_with_key(&keyed, Some(ticket.clone()), &tickets));

        let anonymous = axum::http::HeaderMap::new();
        assert!(auth_with_key(&anonymous, Some(ticket), &tickets));
    }
}