arcly_http/realtime/ws.rs
1//! WebSocket boundary: upgrade, per-socket read/write pumps, event dispatch.
2//!
3//! This is the *only* module that touches `axum::extract::ws` — the analogue of
4//! [`crate::web::boundary`] for the real-time layer. Everything above it speaks
5//! arcly types ([`WsClient`], [`WsMessage`], [`GatewayRuntime`]).
6//!
7//! ## Per-connection model (no hot-path locks)
8//!
9//! ```text
//!           ┌─────────────────── handle_socket task ───────────────────┐
//! socket ──>│ reader: stream.next() ─> dispatch(event) ─> handler fut  │
//!           │ writer: rx.recv()     ─> sink.send(frame)                │
//!           └──────────────────────────────────────────────────────────┘
14//! ```
15//! The reader and writer run as independent halves of the split socket. Inbound
16//! frames are parsed and routed through the gateway's `dispatch` table (an
17//! immutable `&HashMap` — lock-free read). Outbound frames are produced by any
18//! task via the registry's sharded channels and drained by this socket's writer.
19
20use std::sync::Arc;
21
22use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
23use axum::http::HeaderMap;
24use axum::routing::{get, MethodRouter};
25use futures::{SinkExt, StreamExt};
26use tokio::sync::{mpsc, oneshot};
27
28use crate::core::engine::FrozenDiContainer;
29use crate::realtime::connection::{ConnectionRegistry, WsClient, WsMessage};
30use crate::realtime::gateway::GatewayRuntime;
31use crate::web::context::Claims;
32
/// Runtime knobs for a single WebSocket gateway, sourced from `LaunchConfig`
/// when the gateway is mounted.
#[derive(Debug, Clone, Copy)]
pub struct WsTuning {
    /// Capacity of each socket's outbound queue. This is what bounds the
    /// memory a slow client can make the server hold on its behalf.
    pub outbound_buffer: usize,
    /// Ceiling on simultaneously open sockets, counted across all gateways;
    /// `0` means unlimited. Once the ceiling is reached, upgrade requests are
    /// refused with `503` before any socket exists.
    pub max_connections: usize,
    /// Cadence of server→client Pings; a zero duration disables them. Each
    /// Ping provokes a Pong, which feeds the idle sweeper's `last_seen`.
    pub ping_interval: std::time::Duration,
}
45
46/// Build the axum `MethodRouter` that upgrades HTTP→WebSocket for one gateway.
47///
48/// If a `JwtService` has been provided in the DI container, the
49/// `Authorization: Bearer <token>` header is decoded during the WebSocket
50/// handshake and the resulting claims are threaded through to every `WsClient`
51/// so gateway handlers can call `client.claims()` for auth decisions.
52pub fn ws_route(
53 runtime: &'static GatewayRuntime,
54 registry: &'static ConnectionRegistry,
55 container: &'static FrozenDiContainer,
56 tuning: WsTuning,
57) -> MethodRouter {
58 let handler = move |ws: WebSocketUpgrade, headers: HeaderMap| async move {
59 // Admission control happens BEFORE the upgrade — past the cap no
60 // socket, queue, or registry entry is ever created.
61 if tuning.max_connections > 0 && registry.connection_count() >= tuning.max_connections {
62 metrics::counter!("ws_upgrades_refused_total").increment(1);
63 return axum::http::Response::builder()
64 .status(503)
65 .header("retry-after", "5")
66 .body(axum::body::Body::from("websocket capacity reached"))
67 .expect("static refusal");
68 }
69 // The SAME unified extraction as the HTTP boundary (pipeline):
70 // trace + tenant + credentials in one pass. The handshake
71 // authenticates once; gateway handlers see claims AND the resolved
72 // tenant, and the connection inherits the caller's trace identity.
73 let provenance = crate::pipeline::Provenance::from_headers(&headers, container).await;
74 tracing::debug!(
75 trace_id = %crate::observability::lean_telemetry::hex_encode(&provenance.trace.trace_id),
76 tenant = provenance.tenant.as_deref().map(|t| t.id.as_str()).unwrap_or(""),
77 "WS handshake provenance"
78 );
79 ws.on_upgrade(move |socket| {
80 handle_socket(
81 socket,
82 runtime,
83 registry,
84 provenance.claims,
85 provenance.tenant,
86 tuning,
87 )
88 })
89 };
90 get(handler)
91}
92
93/// Drive one upgraded socket to completion: register, pump, dispatch, drain.
94async fn handle_socket(
95 socket: WebSocket,
96 runtime: &'static GatewayRuntime,
97 registry: &'static ConnectionRegistry,
98 claims: Option<Arc<Claims>>,
99 tenant: Option<Arc<crate::web::tenant::TenantConfig>>,
100 tuning: WsTuning,
101) {
102 let (mut sink, mut stream) = socket.split();
103
104 // Outbound queue: any task enqueues, this socket's writer drains.
105 // **Bounded** — the depth is the per-socket memory ceiling; a client
106 // that can't drain it gets evicted by the registry, never buffered
107 // without limit.
108 let (tx, mut rx) = mpsc::channel::<WsMessage>(tuning.outbound_buffer.max(1));
109 let id = registry.register(tx, claims.clone());
110 let client = WsClient::__new(id, registry, claims, tenant);
111
112 // One-shot signal: when the writer exits for *any* reason (peer error,
113 // server-initiated Close, or channel closed), the reader is unblocked so it
114 // stops polling the stream and runs on_disconnect + unregister. Without this,
115 // a server-initiated close would leave the reader blocked on stream.next()
116 // indefinitely if the peer never sends a Close echo.
117 let (close_tx, mut close_rx) = oneshot::channel::<()>();
118
119 // Writer half — owns the sink; exits when the queue closes or the peer
120 // dies. A periodic Ping (when configured) keeps NATs/proxies open and
121 // provokes pongs that feed the idle sweeper's `last_seen`.
122 let ping_every = tuning.ping_interval;
123 let writer = tokio::spawn(async move {
124 let mut ping = (!ping_every.is_zero()).then(|| {
125 let mut t = tokio::time::interval(ping_every);
126 t.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
127 t
128 });
129 loop {
130 let msg = if let Some(t) = ping.as_mut() {
131 tokio::select! {
132 m = rx.recv() => m,
133 _ = t.tick() => {
134 if sink.send(Message::Ping(Vec::new())).await.is_err() {
135 break;
136 }
137 continue;
138 }
139 }
140 } else {
141 rx.recv().await
142 };
143 let Some(msg) = msg else { break };
144 let frame = match msg {
145 WsMessage::Text(arc) => Message::Text(String::from(arc.as_ref())),
146 WsMessage::Ping => Message::Ping(Vec::new()),
147 WsMessage::Close => {
148 // Send close frame then exit immediately — do NOT loop back
149 // to rx.recv(), which would keep the sink open indefinitely.
150 let _ = sink.send(Message::Close(None)).await;
151 break;
152 }
153 };
154 if sink.send(frame).await.is_err() {
155 break;
156 }
157 }
158 // Dropping close_tx signals the reader regardless of why we exited.
159 drop(close_tx);
160 });
161
162 (runtime.on_connect)(client.clone()).await;
163
164 // Reader half — routes inbound frames to subscribed handlers.
165 // Also watches for the writer-exit signal so a server-initiated close
166 // (WsMessage::Close enqueued by a handler) terminates the reader promptly.
167 loop {
168 tokio::select! {
169 biased;
170 // Writer exited (server-initiated close or peer write error).
171 _ = &mut close_rx => break,
172 frame = stream.next() => match frame {
173 None => break,
174 Some(Err(_)) => break,
175 Some(Ok(frame)) => {
176 // Any inbound frame (including pongs from our pings)
177 // proves the link is alive for the idle sweeper.
178 registry.touch(id);
179 match frame {
180 Message::Text(text) => dispatch_event(runtime, &client, &text).await,
181 Message::Binary(_) => { /* binary multiplexing not enabled */ }
182 Message::Close(_) => break,
183 Message::Ping(_) | Message::Pong(_) => { /* axum auto-replies to pings */ }
184 }
185 }
186 }
187 }
188 }
189
190 (runtime.on_disconnect)(client.clone()).await;
191 registry.unregister(id);
192 writer.abort();
193}
194
195/// Parse one `{ "event": ..., "data": ... }` envelope and invoke its handler.
196/// Unknown events and malformed frames are ignored (a hostile client can't
197/// crash the dispatcher).
198async fn dispatch_event(runtime: &'static GatewayRuntime, client: &WsClient, raw: &str) {
199 let Ok(value) = serde_json::from_str::<serde_json::Value>(raw) else {
200 return;
201 };
202 let Some(event) = value.get("event").and_then(|e| e.as_str()) else {
203 return;
204 };
205 let Some(handler) = runtime.handler(event) else {
206 return;
207 };
208
209 let data = value
210 .get("data")
211 .cloned()
212 .unwrap_or(serde_json::Value::Null);
213 let data_str: Arc<str> = Arc::from(serde_json::to_string(&data).unwrap_or_default());
214
215 metrics::counter!("ws_messages_in_total").increment(1);
216 // Handler errors stay at the transport edge — gateways own their own
217 // error-to-client signalling — but they are counted and logged so a
218 // misbehaving event handler is visible on dashboards.
219 if let Err(e) = handler(client.clone(), data_str).await {
220 metrics::counter!("ws_handler_errors_total").increment(1);
221 tracing::debug!(conn = client.id(), event, error = %e, "gateway handler error");
222 }
223}