Skip to main content

arcly_http/realtime/
ws.rs

1//! WebSocket boundary: upgrade, per-socket read/write pumps, event dispatch.
2//!
3//! This is the *only* module that touches `axum::extract::ws` — the analogue of
4//! [`crate::web::boundary`] for the real-time layer. Everything above it speaks
5//! arcly types ([`WsClient`], [`WsMessage`], [`GatewayRuntime`]).
6//!
7//! ## Per-connection model (no hot-path locks)
8//!
9//! ```text
10//!            ┌─────────────────── handle_socket task ───────────────────┐
11//!  socket ──>│ reader: stream.next() ─> dispatch(event) ─> handler fut  │
12//!            │ writer: rx.recv()     ─> sink.send(frame)                │
13//!            └───────────────────────────────────────────────────────────┘
14//! ```
15//! The reader and writer run as independent halves of the split socket. Inbound
16//! frames are parsed and routed through the gateway's `dispatch` table (an
17//! immutable `&HashMap` — lock-free read). Outbound frames are produced by any
18//! task via the registry's sharded channels and drained by this socket's writer.
19
20use std::sync::Arc;
21
22use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
23use axum::http::HeaderMap;
24use axum::routing::{get, MethodRouter};
25use futures::{SinkExt, StreamExt};
26use tokio::sync::{mpsc, oneshot};
27
28use crate::core::engine::FrozenDiContainer;
29use crate::realtime::connection::{ConnectionRegistry, WsClient, WsMessage};
30use crate::realtime::gateway::GatewayRuntime;
31use crate::web::context::Claims;
32
/// Per-gateway runtime tuning, sourced from `LaunchConfig` at mount time.
///
/// The struct is `Copy`: `ws_route` moves it into the upgrade closure and
/// hands a copy to every socket task. It is `#[non_exhaustive]`, so code
/// outside this crate cannot build it with a struct literal and new knobs can
/// be added without a breaking change.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct WsTuning {
    /// Outbound queue depth per socket — the slow-client memory ceiling.
    /// Used as the capacity of each socket's bounded outbound channel; a
    /// value of `0` is raised to `1` when the channel is created.
    pub outbound_buffer: usize,
    /// Hard cap on concurrent sockets across all gateways (`0` = unlimited);
    /// beyond it upgrades are refused with `503` before any socket exists.
    /// The cap is compared against the registry's connection count at
    /// handshake time.
    pub max_connections: usize,
    /// Server→client Ping cadence (`ZERO` disables). Pings provoke pongs,
    /// which feed the idle sweeper's `last_seen`.
    pub ping_interval: std::time::Duration,
}
46
47/// Build the axum `MethodRouter` that upgrades HTTP→WebSocket for one gateway.
48///
49/// If a `JwtService` has been provided in the DI container, the
50/// `Authorization: Bearer <token>` header is decoded during the WebSocket
51/// handshake and the resulting claims are threaded through to every `WsClient`
52/// so gateway handlers can call `client.claims()` for auth decisions.
53pub fn ws_route(
54    runtime: &'static GatewayRuntime,
55    registry: &'static ConnectionRegistry,
56    container: &'static FrozenDiContainer,
57    tuning: WsTuning,
58) -> MethodRouter {
59    let handler = move |ws: WebSocketUpgrade, headers: HeaderMap| async move {
60        // Admission control happens BEFORE the upgrade — past the cap no
61        // socket, queue, or registry entry is ever created.
62        if tuning.max_connections > 0 && registry.connection_count() >= tuning.max_connections {
63            metrics::counter!("ws_upgrades_refused_total").increment(1);
64            return axum::http::Response::builder()
65                .status(503)
66                .header("retry-after", "5")
67                .body(axum::body::Body::from("websocket capacity reached"))
68                .expect("static refusal");
69        }
70        // The SAME unified extraction as the HTTP boundary (pipeline):
71        // trace + tenant + credentials in one pass. The handshake
72        // authenticates once; gateway handlers see claims AND the resolved
73        // tenant, and the connection inherits the caller's trace identity.
74        let provenance = crate::pipeline::Provenance::from_headers(&headers, container).await;
75        tracing::debug!(
76            trace_id = %crate::observability::lean_telemetry::hex_encode(&provenance.trace.trace_id),
77            tenant = provenance.tenant.as_deref().map(|t| t.id.as_str()).unwrap_or(""),
78            "WS handshake provenance"
79        );
80        ws.on_upgrade(move |socket| {
81            handle_socket(
82                socket,
83                runtime,
84                registry,
85                provenance.claims,
86                provenance.tenant,
87                tuning,
88            )
89        })
90    };
91    get(handler)
92}
93
94/// Drive one upgraded socket to completion: register, pump, dispatch, drain.
95async fn handle_socket(
96    socket: WebSocket,
97    runtime: &'static GatewayRuntime,
98    registry: &'static ConnectionRegistry,
99    claims: Option<Arc<Claims>>,
100    tenant: Option<Arc<crate::web::tenant::TenantConfig>>,
101    tuning: WsTuning,
102) {
103    let (mut sink, mut stream) = socket.split();
104
105    // Outbound queue: any task enqueues, this socket's writer drains.
106    // **Bounded** — the depth is the per-socket memory ceiling; a client
107    // that can't drain it gets evicted by the registry, never buffered
108    // without limit.
109    let (tx, mut rx) = mpsc::channel::<WsMessage>(tuning.outbound_buffer.max(1));
110    let id = registry.register(tx, claims.clone());
111    let client = WsClient::__new(id, registry, claims, tenant);
112
113    // One-shot signal: when the writer exits for *any* reason (peer error,
114    // server-initiated Close, or channel closed), the reader is unblocked so it
115    // stops polling the stream and runs on_disconnect + unregister. Without this,
116    // a server-initiated close would leave the reader blocked on stream.next()
117    // indefinitely if the peer never sends a Close echo.
118    let (close_tx, mut close_rx) = oneshot::channel::<()>();
119
120    // Writer half — owns the sink; exits when the queue closes or the peer
121    // dies. A periodic Ping (when configured) keeps NATs/proxies open and
122    // provokes pongs that feed the idle sweeper's `last_seen`.
123    let ping_every = tuning.ping_interval;
124    let writer = tokio::spawn(async move {
125        let mut ping = (!ping_every.is_zero()).then(|| {
126            let mut t = tokio::time::interval(ping_every);
127            t.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
128            t
129        });
130        loop {
131            let msg = if let Some(t) = ping.as_mut() {
132                tokio::select! {
133                    m = rx.recv() => m,
134                    _ = t.tick() => {
135                        if sink.send(Message::Ping(bytes::Bytes::new())).await.is_err() {
136                            break;
137                        }
138                        continue;
139                    }
140                }
141            } else {
142                rx.recv().await
143            };
144            let Some(msg) = msg else { break };
145            let frame = match msg {
146                WsMessage::Text(arc) => Message::Text(arc.as_ref().into()),
147                WsMessage::Ping => Message::Ping(bytes::Bytes::new()),
148                WsMessage::Close => {
149                    // Send close frame then exit immediately — do NOT loop back
150                    // to rx.recv(), which would keep the sink open indefinitely.
151                    let _ = sink.send(Message::Close(None)).await;
152                    break;
153                }
154            };
155            if sink.send(frame).await.is_err() {
156                break;
157            }
158        }
159        // Dropping close_tx signals the reader regardless of why we exited.
160        drop(close_tx);
161    });
162
163    (runtime.on_connect)(client.clone()).await;
164
165    // Reader half — routes inbound frames to subscribed handlers.
166    // Also watches for the writer-exit signal so a server-initiated close
167    // (WsMessage::Close enqueued by a handler) terminates the reader promptly.
168    loop {
169        tokio::select! {
170            biased;
171            // Writer exited (server-initiated close or peer write error).
172            _ = &mut close_rx => break,
173            frame = stream.next() => match frame {
174                None => break,
175                Some(Err(_)) => break,
176                Some(Ok(frame)) => {
177                    // Any inbound frame (including pongs from our pings)
178                    // proves the link is alive for the idle sweeper.
179                    registry.touch(id);
180                    match frame {
181                        Message::Text(text) => dispatch_event(runtime, &client, &text).await,
182                        Message::Binary(_) => { /* binary multiplexing not enabled */ }
183                        Message::Close(_) => break,
184                        Message::Ping(_) | Message::Pong(_) => { /* axum auto-replies to pings */ }
185                    }
186                }
187            }
188        }
189    }
190
191    (runtime.on_disconnect)(client.clone()).await;
192    registry.unregister(id);
193    writer.abort();
194}
195
196/// Parse one `{ "event": ..., "data": ... }` envelope and invoke its handler.
197/// Unknown events and malformed frames are ignored (a hostile client can't
198/// crash the dispatcher).
199async fn dispatch_event(runtime: &'static GatewayRuntime, client: &WsClient, raw: &str) {
200    let Ok(value) = serde_json::from_str::<serde_json::Value>(raw) else {
201        return;
202    };
203    let Some(event) = value.get("event").and_then(|e| e.as_str()) else {
204        return;
205    };
206    let Some(handler) = runtime.handler(event) else {
207        return;
208    };
209
210    let data = value
211        .get("data")
212        .cloned()
213        .unwrap_or(serde_json::Value::Null);
214    let data_str: Arc<str> = Arc::from(serde_json::to_string(&data).unwrap_or_default());
215
216    metrics::counter!("ws_messages_in_total").increment(1);
217    // Handler errors stay at the transport edge — gateways own their own
218    // error-to-client signalling — but they are counted and logged so a
219    // misbehaving event handler is visible on dashboards.
220    if let Err(e) = handler(client.clone(), data_str).await {
221        metrics::counter!("ws_handler_errors_total").increment(1);
222        tracing::debug!(conn = client.id(), event, error = %e, "gateway handler error");
223    }
224}