Skip to main content

arcly_http/realtime/
ws.rs

1//! WebSocket boundary: upgrade, per-socket read/write pumps, event dispatch.
2//!
3//! This is the *only* module that touches `axum::extract::ws` — the analogue of
4//! [`crate::web::boundary`] for the real-time layer. Everything above it speaks
5//! arcly types ([`WsClient`], [`WsMessage`], [`GatewayRuntime`]).
6//!
7//! ## Per-connection model (no hot-path locks)
8//!
9//! ```text
10//!            ┌─────────────────── handle_socket task ───────────────────┐
11//!  socket ──>│ reader: stream.next() ─> dispatch(event) ─> handler fut  │
12//!            │ writer: rx.recv()     ─> sink.send(frame)                │
13//!            └───────────────────────────────────────────────────────────┘
14//! ```
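//! The reader and writer own the two halves of the split socket: the reader
//! runs inside the `handle_socket` task itself, while the writer is a separate
//! task that `handle_socket` spawns (the box above is the socket's lifetime,
//! not a single task). Inbound frames are parsed and routed through the
//! gateway's `dispatch` table (an immutable `&HashMap` — lock-free read).
//! Outbound frames are produced by any task via the registry's sharded
//! channels and drained by this socket's writer.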
19
20use std::sync::Arc;
21
22use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
23use axum::http::HeaderMap;
24use axum::routing::{get, MethodRouter};
25use futures::{SinkExt, StreamExt};
26use tokio::sync::{mpsc, oneshot};
27
28use crate::core::engine::FrozenDiContainer;
29use crate::realtime::connection::{ConnectionRegistry, WsClient, WsMessage};
30use crate::realtime::gateway::GatewayRuntime;
31use crate::web::context::Claims;
32
/// Knobs applied to every socket of one gateway; filled in from `LaunchConfig`
/// when the gateway is mounted.
#[derive(Clone, Copy, Debug)]
pub struct WsTuning {
    /// Capacity of each socket's outbound queue. This is what bounds the memory
    /// a slow client can pin on the server.
    pub outbound_buffer: usize,
    /// Upper bound on concurrent sockets, shared by all gateways; `0` means no
    /// limit. It is checked when a client asks to upgrade, and refused
    /// upgrades get a `503` before any socket exists.
    pub max_connections: usize,
    /// Period between server-sent Pings; `Duration::ZERO` switches them off.
    /// The pongs they provoke refresh the idle sweeper's `last_seen`.
    pub ping_interval: std::time::Duration,
}
45
46/// Build the axum `MethodRouter` that upgrades HTTP→WebSocket for one gateway.
47///
48/// If a `JwtService` has been provided in the DI container, the
49/// `Authorization: Bearer <token>` header is decoded during the WebSocket
50/// handshake and the resulting claims are threaded through to every `WsClient`
51/// so gateway handlers can call `client.claims()` for auth decisions.
52pub fn ws_route(
53    runtime: &'static GatewayRuntime,
54    registry: &'static ConnectionRegistry,
55    container: &'static FrozenDiContainer,
56    tuning: WsTuning,
57) -> MethodRouter {
58    let handler = move |ws: WebSocketUpgrade, headers: HeaderMap| async move {
59        // Admission control happens BEFORE the upgrade — past the cap no
60        // socket, queue, or registry entry is ever created.
61        if tuning.max_connections > 0 && registry.connection_count() >= tuning.max_connections {
62            metrics::counter!("ws_upgrades_refused_total").increment(1);
63            return axum::http::Response::builder()
64                .status(503)
65                .header("retry-after", "5")
66                .body(axum::body::Body::from("websocket capacity reached"))
67                .expect("static refusal");
68        }
69        // The SAME unified extraction as the HTTP boundary (pipeline):
70        // trace + tenant + credentials in one pass. The handshake
71        // authenticates once; gateway handlers see claims AND the resolved
72        // tenant, and the connection inherits the caller's trace identity.
73        let provenance = crate::pipeline::Provenance::from_headers(&headers, container).await;
74        tracing::debug!(
75            trace_id = %crate::observability::lean_telemetry::hex_encode(&provenance.trace.trace_id),
76            tenant = provenance.tenant.as_deref().map(|t| t.id.as_str()).unwrap_or(""),
77            "WS handshake provenance"
78        );
79        ws.on_upgrade(move |socket| {
80            handle_socket(
81                socket,
82                runtime,
83                registry,
84                provenance.claims,
85                provenance.tenant,
86                tuning,
87            )
88        })
89    };
90    get(handler)
91}
92
93/// Drive one upgraded socket to completion: register, pump, dispatch, drain.
94async fn handle_socket(
95    socket: WebSocket,
96    runtime: &'static GatewayRuntime,
97    registry: &'static ConnectionRegistry,
98    claims: Option<Arc<Claims>>,
99    tenant: Option<Arc<crate::web::tenant::TenantConfig>>,
100    tuning: WsTuning,
101) {
102    let (mut sink, mut stream) = socket.split();
103
104    // Outbound queue: any task enqueues, this socket's writer drains.
105    // **Bounded** — the depth is the per-socket memory ceiling; a client
106    // that can't drain it gets evicted by the registry, never buffered
107    // without limit.
108    let (tx, mut rx) = mpsc::channel::<WsMessage>(tuning.outbound_buffer.max(1));
109    let id = registry.register(tx, claims.clone());
110    let client = WsClient::__new(id, registry, claims, tenant);
111
112    // One-shot signal: when the writer exits for *any* reason (peer error,
113    // server-initiated Close, or channel closed), the reader is unblocked so it
114    // stops polling the stream and runs on_disconnect + unregister. Without this,
115    // a server-initiated close would leave the reader blocked on stream.next()
116    // indefinitely if the peer never sends a Close echo.
117    let (close_tx, mut close_rx) = oneshot::channel::<()>();
118
119    // Writer half — owns the sink; exits when the queue closes or the peer
120    // dies. A periodic Ping (when configured) keeps NATs/proxies open and
121    // provokes pongs that feed the idle sweeper's `last_seen`.
122    let ping_every = tuning.ping_interval;
123    let writer = tokio::spawn(async move {
124        let mut ping = (!ping_every.is_zero()).then(|| {
125            let mut t = tokio::time::interval(ping_every);
126            t.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
127            t
128        });
129        loop {
130            let msg = if let Some(t) = ping.as_mut() {
131                tokio::select! {
132                    m = rx.recv() => m,
133                    _ = t.tick() => {
134                        if sink.send(Message::Ping(Vec::new())).await.is_err() {
135                            break;
136                        }
137                        continue;
138                    }
139                }
140            } else {
141                rx.recv().await
142            };
143            let Some(msg) = msg else { break };
144            let frame = match msg {
145                WsMessage::Text(arc) => Message::Text(String::from(arc.as_ref())),
146                WsMessage::Ping => Message::Ping(Vec::new()),
147                WsMessage::Close => {
148                    // Send close frame then exit immediately — do NOT loop back
149                    // to rx.recv(), which would keep the sink open indefinitely.
150                    let _ = sink.send(Message::Close(None)).await;
151                    break;
152                }
153            };
154            if sink.send(frame).await.is_err() {
155                break;
156            }
157        }
158        // Dropping close_tx signals the reader regardless of why we exited.
159        drop(close_tx);
160    });
161
162    (runtime.on_connect)(client.clone()).await;
163
164    // Reader half — routes inbound frames to subscribed handlers.
165    // Also watches for the writer-exit signal so a server-initiated close
166    // (WsMessage::Close enqueued by a handler) terminates the reader promptly.
167    loop {
168        tokio::select! {
169            biased;
170            // Writer exited (server-initiated close or peer write error).
171            _ = &mut close_rx => break,
172            frame = stream.next() => match frame {
173                None => break,
174                Some(Err(_)) => break,
175                Some(Ok(frame)) => {
176                    // Any inbound frame (including pongs from our pings)
177                    // proves the link is alive for the idle sweeper.
178                    registry.touch(id);
179                    match frame {
180                        Message::Text(text) => dispatch_event(runtime, &client, &text).await,
181                        Message::Binary(_) => { /* binary multiplexing not enabled */ }
182                        Message::Close(_) => break,
183                        Message::Ping(_) | Message::Pong(_) => { /* axum auto-replies to pings */ }
184                    }
185                }
186            }
187        }
188    }
189
190    (runtime.on_disconnect)(client.clone()).await;
191    registry.unregister(id);
192    writer.abort();
193}
194
195/// Parse one `{ "event": ..., "data": ... }` envelope and invoke its handler.
196/// Unknown events and malformed frames are ignored (a hostile client can't
197/// crash the dispatcher).
198async fn dispatch_event(runtime: &'static GatewayRuntime, client: &WsClient, raw: &str) {
199    let Ok(value) = serde_json::from_str::<serde_json::Value>(raw) else {
200        return;
201    };
202    let Some(event) = value.get("event").and_then(|e| e.as_str()) else {
203        return;
204    };
205    let Some(handler) = runtime.handler(event) else {
206        return;
207    };
208
209    let data = value
210        .get("data")
211        .cloned()
212        .unwrap_or(serde_json::Value::Null);
213    let data_str: Arc<str> = Arc::from(serde_json::to_string(&data).unwrap_or_default());
214
215    metrics::counter!("ws_messages_in_total").increment(1);
216    // Handler errors stay at the transport edge — gateways own their own
217    // error-to-client signalling — but they are counted and logged so a
218    // misbehaving event handler is visible on dashboards.
219    if let Err(e) = handler(client.clone(), data_str).await {
220        metrics::counter!("ws_handler_errors_total").increment(1);
221        tracing::debug!(conn = client.id(), event, error = %e, "gateway handler error");
222    }
223}