Skip to main content

pylon_runtime/
shard_ws.rs

//! Bidirectional WebSocket server for real-time shards.
//!
//! Runs on its own port (typically `pylon_port + 3`). Each connection:
//!
//! 1. Parses the request path for `?shard=<id>&sid=<subscriber>`.
//! 2. Looks up the shard in the [`DynShardRegistry`].
//! 3. Runs the subscribe authorization hook.
//! 4. Registers a [`SnapshotSink`] that writes binary frames to the socket.
//! 5. Reads text/binary frames from the client and pushes them as inputs.
//! 6. Cleans up on disconnect.
//!
//! Each client gets its own dedicated thread. For larger deployments,
//! swap in an async runtime; for pylon's current scale, thread-per-conn
//! is simpler and fine.
15
16use std::net::{TcpListener, TcpStream};
17use std::sync::{Arc, Mutex};
18use std::thread;
19use std::time::Duration;
20
21use pylon_auth::SessionStore;
22use pylon_realtime::{DynShardRegistry, ShardAuth, ShardError, SubscriberId};
23use tungstenite::{accept_hdr, handshake::server::Request, Message};
24
25use crate::ip_limit::IpConnCounter;
26
27// ---------------------------------------------------------------------------
28// Start
29// ---------------------------------------------------------------------------
30
31/// Run a WebSocket server that accepts shard connections.
32///
33/// Blocking. Spawn on a background thread.
34pub fn start_shard_ws_server(
35    registry: Arc<dyn DynShardRegistry>,
36    sessions: Arc<SessionStore>,
37    port: u16,
38) {
39    let listener = match TcpListener::bind(format!("0.0.0.0:{port}")) {
40        Ok(l) => l,
41        Err(e) => {
42            tracing::warn!("[shard-ws] failed to bind port {port}: {e}");
43            return;
44        }
45    };
46    tracing::warn!("[shard-ws] listening on ws://0.0.0.0:{port}");
47
48    // Per-IP cap so a single client can't open a swarm of shard WS
49    // connections to exhaust the per-thread resource budget. Games with
50    // many tabs/devices per household still get 64 concurrent shards.
51    let ip_counter = Arc::new(IpConnCounter::default());
52
53    for stream in listener.incoming() {
54        let stream = match stream {
55            Ok(s) => s,
56            Err(_) => continue,
57        };
58        let ip = match stream.peer_addr() {
59            Ok(addr) => addr.ip(),
60            Err(_) => continue,
61        };
62        let guard = match ip_counter.acquire(ip) {
63            Some(g) => g,
64            None => continue,
65        };
66        let registry = Arc::clone(&registry);
67        let sessions = Arc::clone(&sessions);
68        thread::spawn(move || {
69            // Holding `_guard` for the life of this thread (which lives for
70            // the full connection) is what ties the IP slot to the socket.
71            let _guard = guard;
72            if let Err(e) = handle_connection(stream, registry, sessions) {
73                tracing::warn!("[shard-ws] connection error: {e}");
74            }
75        });
76    }
77}
78
79// ---------------------------------------------------------------------------
80// Per-connection handler
81// ---------------------------------------------------------------------------
82
83fn handle_connection(
84    stream: TcpStream,
85    registry: Arc<dyn DynShardRegistry>,
86    sessions: Arc<SessionStore>,
87) -> Result<(), String> {
88    // Capture the HTTP handshake so we can read the Request-URI and headers.
89    let params = std::sync::Arc::new(Mutex::new(HandshakeParams::default()));
90    let params_clone = Arc::clone(&params);
91
92    use tungstenite::handshake::server::{ErrorResponse, Response};
93    let ws = accept_hdr(
94        stream,
95        |req: &Request, mut resp: Response| -> Result<Response, ErrorResponse> {
96            let uri = req.uri().to_string();
97            let mut p = params_clone.lock().unwrap();
98            p.uri = uri;
99            let mut selected_protocol: Option<String> = None;
100            for (name, value) in req.headers() {
101                let lower = name.as_str().to_ascii_lowercase();
102                if lower == "authorization" {
103                    if let Ok(v) = value.to_str() {
104                        p.auth_header = Some(v.to_string());
105                    }
106                } else if lower == "sec-websocket-protocol" {
107                    // Accept a `bearer.<url-encoded-token>` subprotocol as an
108                    // alternative to the Authorization header. Browsers can't
109                    // set WebSocket headers directly, so this is how a web
110                    // client carries a bearer token without putting it in the
111                    // URL. Pick the first token that matches our prefix; echo
112                    // the exact chosen subprotocol back in the handshake
113                    // response, per RFC 6455 ยง11.3.4 (otherwise some browsers
114                    // refuse the connection).
115                    if let Ok(v) = value.to_str() {
116                        for proto in v.split(',').map(str::trim) {
117                            if let Some(encoded) = proto.strip_prefix("bearer.") {
118                                if let Ok(decoded) = urldecode_strict(encoded) {
119                                    p.bearer_from_subprotocol = Some(decoded);
120                                    selected_protocol = Some(proto.to_string());
121                                    break;
122                                }
123                            }
124                        }
125                    }
126                }
127            }
128            if let Some(chosen) = selected_protocol {
129                if let Ok(hv) = tungstenite::http::HeaderValue::from_str(&chosen) {
130                    resp.headers_mut().insert("Sec-WebSocket-Protocol", hv);
131                }
132            }
133            Ok(resp)
134        },
135    )
136    .map_err(|e| format!("handshake: {e}"))?;
137
138    let params = params.lock().unwrap().clone();
139    let query = params
140        .uri
141        .split_once('?')
142        .map(|(_, q)| q.to_string())
143        .unwrap_or_default();
144
145    let shard_id = query_param(&query, "shard").ok_or("missing ?shard= parameter")?;
146    let sid = query_param(&query, "sid").unwrap_or_else(|| "anon".to_string());
147
148    // Resolve auth token. Preference order:
149    //   1. Authorization: Bearer ...   (native clients)
150    //   2. Sec-WebSocket-Protocol: bearer.<token>   (browsers)
151    //
152    // The legacy `?token=` query-string path was removed: it leaked the
153    // bearer token into proxy access logs, Referer headers, and browser
154    // history. All supported clients can send the subprotocol or header.
155    let token = params
156        .auth_header
157        .as_deref()
158        .and_then(|h| h.strip_prefix("Bearer "))
159        .map(|t| t.to_string())
160        .or_else(|| params.bearer_from_subprotocol.clone());
161    let auth_ctx = sessions.resolve(token.as_deref());
162    let shard_auth = ShardAuth {
163        user_id: auth_ctx.user_id.clone(),
164        is_admin: auth_ctx.is_admin,
165    };
166
167    let shard = registry
168        .get(&shard_id)
169        .ok_or_else(|| format!("shard \"{shard_id}\" not found"))?;
170
171    let ws = Arc::new(Mutex::new(ws));
172    let subscriber_id = SubscriberId::new(sid.clone());
173
174    // Build the sink: every snapshot broadcast becomes a WS binary frame.
175    let ws_for_sink = Arc::clone(&ws);
176    let sink: pylon_realtime::SnapshotSink = Box::new(move |tick, bytes| {
177        let mut payload = Vec::with_capacity(8 + bytes.len() + 2);
178        payload.extend_from_slice(&tick.to_be_bytes());
179        payload.extend_from_slice(bytes);
180        if let Ok(mut s) = ws_for_sink.lock() {
181            let _ = s.send(Message::Binary(payload.into()));
182        }
183    });
184
185    // Register the subscriber, respecting auth.
186    match shard.add_subscriber(subscriber_id.clone(), sink, &shard_auth) {
187        Ok(()) => {}
188        Err(ShardError::Unauthorized(reason)) => {
189            let _ = ws
190                .lock()
191                .unwrap()
192                .close(Some(tungstenite::protocol::CloseFrame {
193                    code: tungstenite::protocol::frame::coding::CloseCode::Policy,
194                    reason: format!("unauthorized: {reason}").into(),
195                }));
196            return Ok(());
197        }
198        Err(e) => {
199            let _ = ws
200                .lock()
201                .unwrap()
202                .close(Some(tungstenite::protocol::CloseFrame {
203                    code: tungstenite::protocol::frame::coding::CloseCode::Again,
204                    reason: e.to_string().into(),
205                }));
206            return Ok(());
207        }
208    }
209
210    // Read loop โ€” inbound messages from the client become shard inputs.
211    // Each message is JSON: {"input": ..., "client_seq"?: N}
212    let read_result = loop {
213        let msg = {
214            let mut s = match ws.lock() {
215                Ok(s) => s,
216                Err(_) => break Err("ws lock poisoned".to_string()),
217            };
218            match s.read() {
219                Ok(m) => m,
220                Err(tungstenite::Error::ConnectionClosed) => break Ok(()),
221                Err(tungstenite::Error::AlreadyClosed) => break Ok(()),
222                Err(e) => break Err(format!("ws read: {e}")),
223            }
224        };
225
226        match msg {
227            Message::Text(text) => {
228                process_input(&shard, &subscriber_id, &shard_auth, text.as_str());
229            }
230            Message::Binary(bytes) => {
231                let text = String::from_utf8_lossy(&bytes).to_string();
232                process_input(&shard, &subscriber_id, &shard_auth, &text);
233            }
234            Message::Ping(payload) => {
235                let _ = ws.lock().unwrap().send(Message::Pong(payload));
236            }
237            Message::Close(_) => break Ok(()),
238            _ => {}
239        }
240    };
241
242    // Clean up.
243    shard.remove_subscriber(&subscriber_id);
244    if let Err(e) = read_result {
245        Err(e)
246    } else {
247        Ok(())
248    }
249}
250
251fn process_input(
252    shard: &Arc<dyn pylon_realtime::DynShard>,
253    subscriber_id: &SubscriberId,
254    shard_auth: &ShardAuth,
255    text: &str,
256) {
257    // Envelope shape: { input, client_seq? }
258    let envelope: serde_json::Value = match serde_json::from_str(text) {
259        Ok(v) => v,
260        Err(_) => return,
261    };
262    let input = envelope
263        .get("input")
264        .cloned()
265        .unwrap_or(serde_json::Value::Null);
266    let client_seq = envelope.get("client_seq").and_then(|v| v.as_u64());
267    let input_str = serde_json::to_string(&input).unwrap_or_else(|_| "null".into());
268
269    let _ = shard.push_input_json(subscriber_id.clone(), &input_str, client_seq, shard_auth);
270}
271
272// ---------------------------------------------------------------------------
273// Helpers
274// ---------------------------------------------------------------------------
275
/// Values captured from the HTTP upgrade request during the WebSocket
/// handshake, for use once the handshake has completed.
#[derive(Clone, Default)]
struct HandshakeParams {
    /// Request-URI as sent by the client, including any query string.
    uri: String,
    /// Raw `Authorization` header value, when present and valid as a string.
    auth_header: Option<String>,
    /// Percent-decoded token taken from a `bearer.<token>` subprotocol.
    bearer_from_subprotocol: Option<String>,
}
282
/// Percent-decode `s`, rejecting anything malformed.
///
/// Used for the WS subprotocol bearer token so garbage is refused rather
/// than silently accepted. `+` decodes to a space.
///
/// # Errors
///
/// Fails when a `%` is not followed by two characters, when either of
/// those characters is not a hex digit, or when the decoded bytes are not
/// valid UTF-8.
fn urldecode_strict(s: &str) -> Result<String, String> {
    /// Value of one ASCII hex digit.
    fn nibble(byte: u8) -> Result<u8, String> {
        match byte {
            b'0'..=b'9' => Ok(byte - b'0'),
            b'a'..=b'f' => Ok(byte - b'a' + 10),
            b'A'..=b'F' => Ok(byte - b'A' + 10),
            _ => Err("bad hex in percent-encoding".into()),
        }
    }

    let mut decoded = Vec::with_capacity(s.len());
    let mut rest = s.as_bytes();
    while let Some((&head, tail)) = rest.split_first() {
        match head {
            b'%' => {
                // Length is checked before the digits, so a short escape
                // reports truncation even if its one digit is invalid.
                if tail.len() < 2 {
                    return Err("truncated percent-encoding".into());
                }
                let hi = nibble(tail[0])?;
                let lo = nibble(tail[1])?;
                decoded.push((hi << 4) | lo);
                rest = &tail[2..];
            }
            b'+' => {
                decoded.push(b' ');
                rest = tail;
            }
            other => {
                decoded.push(other);
                rest = tail;
            }
        }
    }

    match String::from_utf8(decoded) {
        Ok(text) => Ok(text),
        Err(_) => Err("percent-encoded token is not valid UTF-8".into()),
    }
}
312
313fn query_param(query: &str, key: &str) -> Option<String> {
314    for pair in query.split('&') {
315        let mut it = pair.splitn(2, '=');
316        let k = it.next()?;
317        let v = it.next().unwrap_or("");
318        if k == key {
319            return Some(url_decode(v));
320        }
321    }
322    None
323}
324
/// Lenient percent-decoder for query-string values.
///
/// * `+` becomes a space.
/// * `%XY` with two hex digits becomes the byte `0xXY`.
/// * A `%` that is not followed by two hex digits is kept literally.
///
/// Decoded bytes are collected first and interpreted as UTF-8 once at the
/// end, so multi-byte sequences such as `%C3%A9` come out as one character
/// (`é`) rather than one character per byte. Byte sequences that are not
/// valid UTF-8 are replaced with U+FFFD.
fn url_decode(s: &str) -> String {
    /// Value of one ASCII hex digit. Signs and whitespace are rejected,
    /// unlike `u8::from_str_radix`, which accepts a leading `+`.
    fn hex_val(byte: u8) -> Option<u8> {
        (byte as char).to_digit(16).map(|d| d as u8)
    }

    let bytes = s.as_bytes();
    let mut out: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                out.push(b' ');
                i += 1;
            }
            b'%' => {
                let hi = bytes.get(i + 1).copied().and_then(hex_val);
                let lo = bytes.get(i + 2).copied().and_then(hex_val);
                match hi.zip(lo) {
                    Some((hi, lo)) => {
                        out.push((hi << 4) | lo);
                        i += 3;
                    }
                    // Malformed or truncated escape: keep the `%` and let
                    // the following bytes be handled on their own.
                    None => {
                        out.push(b'%');
                        i += 1;
                    }
                }
            }
            b => {
                out.push(b);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&out).into_owned()
}
354
/// Best-effort helper: set `dur` as the read timeout on `stream`.
///
/// A failure to set the timeout is ignored.
// No caller is visible in this file, hence the dead-code allowance.
#[allow(dead_code)]
fn apply_read_timeout(stream: &TcpStream, dur: Duration) {
    let timeout = Some(dur);
    let _ = stream.set_read_timeout(timeout);
}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain `key=value` pairs are found by key; unknown keys yield `None`.
    #[test]
    fn query_param_parses_basic() {
        let query = "shard=match1&sid=p1";
        assert_eq!(query_param(query, "shard").as_deref(), Some("match1"));
        assert_eq!(query_param(query, "sid").as_deref(), Some("p1"));
        assert!(query_param("shard=match1", "missing").is_none());
    }

    /// Values are percent-decoded before being returned.
    #[test]
    fn query_param_url_decodes() {
        let decoded = query_param("name=hello%20world", "name");
        assert_eq!(decoded.as_deref(), Some("hello world"));
    }
}