arcly_http/realtime/ws.rs
1//! WebSocket boundary: upgrade, per-socket read/write pumps, event dispatch.
2//!
3//! This is the *only* module that touches `axum::extract::ws` — the analogue of
4//! [`crate::web::boundary`] for the real-time layer. Everything above it speaks
5//! arcly types ([`WsClient`], [`WsMessage`], [`GatewayRuntime`]).
6//!
7//! ## Per-connection model (no hot-path locks)
8//!
//! ```text
//!           ┌─────────────────── handle_socket task ───────────────────┐
//! socket ──>│ reader: stream.next() ─> dispatch(event) ─> handler fut  │
//!           │ writer: rx.recv() ─> sink.send(frame)                    │
//!           └──────────────────────────────────────────────────────────┘
//! ```
15//! The reader and writer run as independent halves of the split socket. Inbound
16//! frames are parsed and routed through the gateway's `dispatch` table (an
17//! immutable `&HashMap` — lock-free read). Outbound frames are produced by any
18//! task via the registry's sharded channels and drained by this socket's writer.
19
20use std::sync::Arc;
21
22use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
23use axum::http::HeaderMap;
24use axum::routing::{get, MethodRouter};
25use futures::{SinkExt, StreamExt};
26use tokio::sync::{mpsc, oneshot};
27
28use crate::core::engine::FrozenDiContainer;
29use crate::realtime::connection::{ConnectionRegistry, WsClient, WsMessage};
30use crate::realtime::gateway::GatewayRuntime;
31use crate::web::context::Claims;
32
/// Runtime knobs applied to one gateway's sockets; populated from
/// `LaunchConfig` when the gateway is mounted.
#[non_exhaustive]
#[derive(Debug, Clone, Copy)]
pub struct WsTuning {
    /// Capacity of each socket's outbound queue. This bounds how much memory
    /// a slow client can pin before it is treated as unable to keep up.
    pub outbound_buffer: usize,
    /// Ceiling on simultaneously open sockets across all gateways; `0` means
    /// no limit. Once the ceiling is reached, upgrade requests are answered
    /// with `503` before any socket is created.
    pub max_connections: usize,
    /// Interval between server→client Pings; a zero duration turns them off.
    /// The pongs these provoke refresh the idle sweeper's `last_seen`.
    pub ping_interval: std::time::Duration,
}
46
47/// Build the axum `MethodRouter` that upgrades HTTP→WebSocket for one gateway.
48///
49/// If a `JwtService` has been provided in the DI container, the
50/// `Authorization: Bearer <token>` header is decoded during the WebSocket
51/// handshake and the resulting claims are threaded through to every `WsClient`
52/// so gateway handlers can call `client.claims()` for auth decisions.
53pub fn ws_route(
54 runtime: &'static GatewayRuntime,
55 registry: &'static ConnectionRegistry,
56 container: &'static FrozenDiContainer,
57 tuning: WsTuning,
58) -> MethodRouter {
59 let handler = move |ws: WebSocketUpgrade, headers: HeaderMap| async move {
60 // Admission control happens BEFORE the upgrade — past the cap no
61 // socket, queue, or registry entry is ever created.
62 if tuning.max_connections > 0 && registry.connection_count() >= tuning.max_connections {
63 metrics::counter!("ws_upgrades_refused_total").increment(1);
64 return axum::http::Response::builder()
65 .status(503)
66 .header("retry-after", "5")
67 .body(axum::body::Body::from("websocket capacity reached"))
68 .expect("static refusal");
69 }
70 // The SAME unified extraction as the HTTP boundary (pipeline):
71 // trace + tenant + credentials in one pass. The handshake
72 // authenticates once; gateway handlers see claims AND the resolved
73 // tenant, and the connection inherits the caller's trace identity.
74 let provenance = crate::pipeline::Provenance::from_headers(&headers, container).await;
75 tracing::debug!(
76 trace_id = %crate::observability::lean_telemetry::hex_encode(&provenance.trace.trace_id),
77 tenant = provenance.tenant.as_deref().map(|t| t.id.as_str()).unwrap_or(""),
78 "WS handshake provenance"
79 );
80 ws.on_upgrade(move |socket| {
81 handle_socket(
82 socket,
83 runtime,
84 registry,
85 provenance.claims,
86 provenance.tenant,
87 tuning,
88 )
89 })
90 };
91 get(handler)
92}
93
94/// Drive one upgraded socket to completion: register, pump, dispatch, drain.
95async fn handle_socket(
96 socket: WebSocket,
97 runtime: &'static GatewayRuntime,
98 registry: &'static ConnectionRegistry,
99 claims: Option<Arc<Claims>>,
100 tenant: Option<Arc<crate::web::tenant::TenantConfig>>,
101 tuning: WsTuning,
102) {
103 let (mut sink, mut stream) = socket.split();
104
105 // Outbound queue: any task enqueues, this socket's writer drains.
106 // **Bounded** — the depth is the per-socket memory ceiling; a client
107 // that can't drain it gets evicted by the registry, never buffered
108 // without limit.
109 let (tx, mut rx) = mpsc::channel::<WsMessage>(tuning.outbound_buffer.max(1));
110 let id = registry.register(tx, claims.clone());
111 let client = WsClient::__new(id, registry, claims, tenant);
112
113 // One-shot signal: when the writer exits for *any* reason (peer error,
114 // server-initiated Close, or channel closed), the reader is unblocked so it
115 // stops polling the stream and runs on_disconnect + unregister. Without this,
116 // a server-initiated close would leave the reader blocked on stream.next()
117 // indefinitely if the peer never sends a Close echo.
118 let (close_tx, mut close_rx) = oneshot::channel::<()>();
119
120 // Writer half — owns the sink; exits when the queue closes or the peer
121 // dies. A periodic Ping (when configured) keeps NATs/proxies open and
122 // provokes pongs that feed the idle sweeper's `last_seen`.
123 let ping_every = tuning.ping_interval;
124 let writer = tokio::spawn(async move {
125 let mut ping = (!ping_every.is_zero()).then(|| {
126 let mut t = tokio::time::interval(ping_every);
127 t.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
128 t
129 });
130 loop {
131 let msg = if let Some(t) = ping.as_mut() {
132 tokio::select! {
133 m = rx.recv() => m,
134 _ = t.tick() => {
135 if sink.send(Message::Ping(bytes::Bytes::new())).await.is_err() {
136 break;
137 }
138 continue;
139 }
140 }
141 } else {
142 rx.recv().await
143 };
144 let Some(msg) = msg else { break };
145 let frame = match msg {
146 WsMessage::Text(arc) => Message::Text(arc.as_ref().into()),
147 WsMessage::Ping => Message::Ping(bytes::Bytes::new()),
148 WsMessage::Close => {
149 // Send close frame then exit immediately — do NOT loop back
150 // to rx.recv(), which would keep the sink open indefinitely.
151 let _ = sink.send(Message::Close(None)).await;
152 break;
153 }
154 };
155 if sink.send(frame).await.is_err() {
156 break;
157 }
158 }
159 // Dropping close_tx signals the reader regardless of why we exited.
160 drop(close_tx);
161 });
162
163 (runtime.on_connect)(client.clone()).await;
164
165 // Reader half — routes inbound frames to subscribed handlers.
166 // Also watches for the writer-exit signal so a server-initiated close
167 // (WsMessage::Close enqueued by a handler) terminates the reader promptly.
168 loop {
169 tokio::select! {
170 biased;
171 // Writer exited (server-initiated close or peer write error).
172 _ = &mut close_rx => break,
173 frame = stream.next() => match frame {
174 None => break,
175 Some(Err(_)) => break,
176 Some(Ok(frame)) => {
177 // Any inbound frame (including pongs from our pings)
178 // proves the link is alive for the idle sweeper.
179 registry.touch(id);
180 match frame {
181 Message::Text(text) => dispatch_event(runtime, &client, &text).await,
182 Message::Binary(_) => { /* binary multiplexing not enabled */ }
183 Message::Close(_) => break,
184 Message::Ping(_) | Message::Pong(_) => { /* axum auto-replies to pings */ }
185 }
186 }
187 }
188 }
189 }
190
191 (runtime.on_disconnect)(client.clone()).await;
192 registry.unregister(id);
193 writer.abort();
194}
195
196/// Parse one `{ "event": ..., "data": ... }` envelope and invoke its handler.
197/// Unknown events and malformed frames are ignored (a hostile client can't
198/// crash the dispatcher).
199async fn dispatch_event(runtime: &'static GatewayRuntime, client: &WsClient, raw: &str) {
200 let Ok(value) = serde_json::from_str::<serde_json::Value>(raw) else {
201 return;
202 };
203 let Some(event) = value.get("event").and_then(|e| e.as_str()) else {
204 return;
205 };
206 let Some(handler) = runtime.handler(event) else {
207 return;
208 };
209
210 let data = value
211 .get("data")
212 .cloned()
213 .unwrap_or(serde_json::Value::Null);
214 let data_str: Arc<str> = Arc::from(serde_json::to_string(&data).unwrap_or_default());
215
216 metrics::counter!("ws_messages_in_total").increment(1);
217 // Handler errors stay at the transport edge — gateways own their own
218 // error-to-client signalling — but they are counted and logged so a
219 // misbehaving event handler is visible on dashboards.
220 if let Err(e) = handler(client.clone(), data_str).await {
221 metrics::counter!("ws_handler_errors_total").increment(1);
222 tracing::debug!(conn = client.id(), event, error = %e, "gateway handler error");
223 }
224}