pylon_runtime/ws.rs
1use std::collections::{HashMap, HashSet};
2use std::net::{TcpListener, TcpStream};
3use std::sync::mpsc;
4use std::sync::{Arc, Mutex};
5use std::thread;
6use std::time::Duration;
7
8use pylon_auth::SessionStore;
9use pylon_sync::ChangeEvent;
10use tungstenite::handshake::server::{ErrorResponse, Request, Response};
11use tungstenite::{accept_hdr_with_config, protocol::WebSocketConfig, Message, WebSocket};
12
13use crate::ip_limit::IpConnCounter;
14
15// ---------------------------------------------------------------------------
16// CRDT subscription manager
17//
18// Per-client subscriptions to (entity, row_id) pairs. Lets the binary CRDT
19// broadcast filter to only the clients that asked, instead of fanning out
20// every CRDT write to every connected WS client.
21//
22// Two reverse maps so both hot paths are O(subscribers per row) and
23// O(rows per client): the broadcast looks up subscribers by row, the
24// disconnect cleanup walks rows by client.
25//
26// Subscriptions are explicit and ephemeral — a client subscribes when
27// useLoroDoc(entity, id) mounts, unsubscribes on unmount or disconnect.
28// Server doesn't persist subscriptions across reconnects; the client
29// re-sends them.
30// ---------------------------------------------------------------------------
31
/// Both directions of the CRDT subscription relation, always updated
/// together under one lock.
///
/// Invariant maintained by every method below: `client` is in
/// `by_row[key]` exactly when `key` is in `by_client[client]`, and neither
/// map keeps an empty set around.
#[derive(Default)]
struct SubsState {
    /// (entity, row_id) → set of client_ids subscribed to that row.
    by_row: HashMap<(String, String), HashSet<u64>>,
    /// client_id → set of (entity, row_id) it subscribes to.
    /// Inverted to make disconnect cleanup O(rows per client) instead of
    /// O(total rows in by_row).
    by_client: HashMap<u64, HashSet<(String, String)>>,
}

impl SubsState {
    /// Remove `client_id` from `key`'s subscriber set, dropping the set once
    /// it is empty so `by_row` doesn't accumulate orphan entries.
    fn detach_from_row(&mut self, key: &(String, String), client_id: u64) {
        if let Some(subscribers) = self.by_row.get_mut(key) {
            subscribers.remove(&client_id);
            if subscribers.is_empty() {
                self.by_row.remove(key);
            }
        }
    }
}

/// Registry of which WebSocket clients want binary CRDT frames for which
/// `(entity, row_id)` pairs.
#[derive(Default)]
pub struct CrdtSubscriptions {
    /// Single mutex covers both reverse maps so any pair of operations
    /// (subscribe + unsubscribe across threads, broadcast + disconnect
    /// cleanup) sees a consistent view. Two separate mutexes would let
    /// `subscribe` land in `by_row` while a concurrent `unsubscribe_all`
    /// snapshots `by_client` mid-update, leaving the maps divergent.
    state: Mutex<SubsState>,
}

impl CrdtSubscriptions {
    /// Create an empty registry, already wrapped in an `Arc` for sharing
    /// between reader threads and the broadcast path.
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }

    /// Register a client's interest in a row. Idempotent — re-subscribing
    /// the same client to the same row is a no-op (HashSet semantics).
    pub fn subscribe(&self, client_id: u64, entity: &str, row_id: &str) {
        let key = (entity.to_string(), row_id.to_string());
        let mut state = self.state.lock().unwrap();
        state
            .by_row
            .entry(key.clone())
            .or_default()
            .insert(client_id);
        state.by_client.entry(client_id).or_default().insert(key);
    }

    /// Drop one subscription. Unknown (client, row) pairs are a no-op.
    ///
    /// Empty sets are removed from both maps so long-running connections
    /// that subscribe and unsubscribe to many rows over their lifetime
    /// don't accumulate orphan entries.
    pub fn unsubscribe(&self, client_id: u64, entity: &str, row_id: &str) {
        let key = (entity.to_string(), row_id.to_string());
        let mut state = self.state.lock().unwrap();
        state.detach_from_row(&key, client_id);
        if let Some(rows) = state.by_client.get_mut(&client_id) {
            rows.remove(&key);
            if rows.is_empty() {
                state.by_client.remove(&client_id);
            }
        }
    }

    /// Drop every subscription for a client (called on WS disconnect or
    /// when a send to that client fails). Runs under a single lock hold, so
    /// concurrent `subscribers` snapshots see the client either fully
    /// present or fully gone.
    pub fn unsubscribe_all(&self, client_id: u64) {
        let mut state = self.state.lock().unwrap();
        // The removed set is owned, so it can be walked directly while the
        // row map is mutated — no intermediate copy needed.
        if let Some(rows) = state.by_client.remove(&client_id) {
            for key in &rows {
                state.detach_from_row(key, client_id);
            }
        }
    }

    /// Snapshot the subscriber set for a row (unordered). Returns an owned
    /// `Vec` rather than a guard so the broadcast hot path doesn't hold the
    /// mutex during the per-client send loop.
    pub fn subscribers(&self, entity: &str, row_id: &str) -> Vec<u64> {
        let key = (entity.to_string(), row_id.to_string());
        let state = self.state.lock().unwrap();
        state
            .by_row
            .get(&key)
            .map(|set| set.iter().copied().collect())
            .unwrap_or_default()
    }

    /// Diagnostic: total number of (client, row) pairs.
    pub fn total_subscriptions(&self) -> usize {
        self.state
            .lock()
            .unwrap()
            .by_row
            .values()
            .map(HashSet::len)
            .sum()
    }
}
143
/// Number of shards for distributing WebSocket clients.
///
/// Client ids are sequential, so `id % NUM_SHARDS` spreads them evenly for
/// any shard count; keeping it a power of two lets that modulo reduce to a
/// bit mask. The requirement is enforced at compile time just below.
const NUM_SHARDS: usize = 16;
const _: () = assert!(NUM_SHARDS.is_power_of_two(), "NUM_SHARDS must be a power of two");

/// Maximum number of outbound text messages queued per shard worker.
///
/// Once a shard's broadcast worker falls this many messages behind, further
/// broadcasts are DROPPED for that shard: the enqueue uses `try_send`, so
/// the already-queued backlog is kept and the newest message is the one
/// lost. Slow subscribers can therefore miss messages — but the alternative
/// (an unbounded queue) was OOM when a single stuck client blocked its shard
/// worker.
///
/// Callers that need exact delivery should layer their own retry on top
/// (the change-log cursor protocol already does this for sync).
const BROADCAST_QUEUE_DEPTH: usize = 1024;

/// Read timeout on each WebSocket read. Kept low so the mutex guarding the
/// socket is released frequently, letting the broadcaster get its turn even
/// if the client never sends anything. Previously this was 120s, which meant
/// one quiet client could wedge the shard's writer for up to two minutes.
const WS_READ_TIMEOUT: Duration = Duration::from_millis(200);
163
164/// One entry per connected client. The socket lives behind its OWN
165/// `Mutex`, not a shard-wide one, so the reader thread's blocking
166/// `socket.read()` doesn't hold a lock that covers every client in the
167/// same shard. The broadcaster iterates the client map (outer lock is
168/// brief — O(count of clients in shard)), then grabs each client's
169/// individual mutex to do the `socket.send`. Contention is now per-
170/// client instead of per-shard.
171type ClientSocket = Arc<Mutex<WebSocket<TcpStream>>>;
172
173/// A single shard holding a subset of WebSocket clients.
174///
175/// The outer `Mutex<HashMap>` is held only for insert/remove and while
176/// enumerating client handles — never across I/O.
177struct Shard {
178 clients: Mutex<HashMap<u64, ClientSocket>>,
179}
180
181impl Shard {
182 fn new() -> Self {
183 Self {
184 clients: Mutex::new(HashMap::new()),
185 }
186 }
187
188 fn add(&self, id: u64, ws: WebSocket<TcpStream>) -> ClientSocket {
189 let handle = Arc::new(Mutex::new(ws));
190 self.clients.lock().unwrap().insert(id, Arc::clone(&handle));
191 handle
192 }
193
194 fn remove(&self, id: u64) {
195 self.clients.lock().unwrap().remove(&id);
196 }
197
198 /// Send a message to all clients in this shard.
199 ///
200 /// Snapshot the client handles under the shard lock, drop the shard
201 /// lock, then contend only with per-client mutexes to do the writes.
202 /// This is what lets a reader thread hold its client's mutex for a
203 /// socket.read() without stalling broadcasts for the whole shard.
204 ///
205 /// `msg` is `Arc<str>` rather than `&str` so the caller can serialize
206 /// the JSON exactly once and share the same allocation across all
207 /// 16 shards. Per-client `Message::Text` still allocates an owned
208 /// String (tungstenite 0.24 requires it), but the broadcast no
209 /// longer pays N copies of the JSON across shard channels.
210 fn broadcast(&self, msg: &Arc<str>) {
211 let handles: Vec<(u64, ClientSocket)> = {
212 let clients = self.clients.lock().unwrap();
213 clients.iter().map(|(id, h)| (*id, Arc::clone(h))).collect()
214 };
215 let mut dead: Vec<u64> = Vec::new();
216 for (id, handle) in handles {
217 // `try_lock` would skip clients whose reader is currently
218 // blocked in read(); we prefer `lock()` here so the occasional
219 // broadcaster wait (bounded by the 200ms read timeout) doesn't
220 // drop the message for that client.
221 let mut guard = match handle.lock() {
222 Ok(g) => g,
223 Err(poisoned) => poisoned.into_inner(),
224 };
225 // Owned String per send is the tungstenite 0.24 contract.
226 // The clone here copies the string contents; sharing the
227 // raw bytes via Utf8Bytes would be the next-level
228 // optimization but requires a tungstenite version bump.
229 if guard.send(Message::Text((**msg).to_string())).is_err() {
230 dead.push(id);
231 }
232 }
233 if !dead.is_empty() {
234 let mut clients = self.clients.lock().unwrap();
235 for id in &dead {
236 clients.remove(id);
237 }
238 }
239 }
240
241 /// Send a binary frame to a SPECIFIC subset of this shard's clients.
242 /// Used by the per-client subscription path — `WsHub::broadcast_binary_to`
243 /// computes which ids each shard owns and calls this with just those.
244 ///
245 /// Same per-client lock pattern as `broadcast` / `broadcast_binary`,
246 /// just filtered up front instead of iterating the whole shard.
247 ///
248 /// Returns the list of client ids whose send failed so the caller
249 /// can also clear those ids from the CRDT subscription registry —
250 /// without that step a dead client's subscription entries linger
251 /// until the reader thread notices the EOF and runs unsubscribe_all,
252 /// which can take up to one read-timeout (200ms) longer than the
253 /// send-side death detection.
254 fn send_binary_to(&self, ids: &[u64], msg: &Arc<[u8]>) -> Vec<u64> {
255 let handles: Vec<(u64, ClientSocket)> = {
256 let clients = self.clients.lock().unwrap();
257 ids.iter()
258 .filter_map(|id| clients.get(id).map(|h| (*id, Arc::clone(h))))
259 .collect()
260 };
261 let mut dead: Vec<u64> = Vec::new();
262 for (id, handle) in handles {
263 let mut guard = match handle.lock() {
264 Ok(g) => g,
265 Err(poisoned) => poisoned.into_inner(),
266 };
267 if guard.send(Message::Binary(msg.to_vec())).is_err() {
268 dead.push(id);
269 }
270 }
271 if !dead.is_empty() {
272 let mut clients = self.clients.lock().unwrap();
273 for id in &dead {
274 clients.remove(id);
275 }
276 }
277 dead
278 }
279
280 /// Binary fanout for CRDT updates. Same per-client lock pattern as
281 /// `broadcast` above; the only difference is `Message::Binary` and
282 /// the payload is `Arc<[u8]>` so a single Loro snapshot allocates
283 /// once and the per-client send pays a refcount bump + the
284 /// tungstenite-required Vec clone.
285 fn broadcast_binary(&self, msg: &Arc<[u8]>) {
286 let handles: Vec<(u64, ClientSocket)> = {
287 let clients = self.clients.lock().unwrap();
288 clients.iter().map(|(id, h)| (*id, Arc::clone(h))).collect()
289 };
290 let mut dead: Vec<u64> = Vec::new();
291 for (id, handle) in handles {
292 let mut guard = match handle.lock() {
293 Ok(g) => g,
294 Err(poisoned) => poisoned.into_inner(),
295 };
296 if guard.send(Message::Binary(msg.to_vec())).is_err() {
297 dead.push(id);
298 }
299 }
300 if !dead.is_empty() {
301 let mut clients = self.clients.lock().unwrap();
302 for id in &dead {
303 clients.remove(id);
304 }
305 }
306 }
307
308 fn count(&self) -> usize {
309 self.clients.lock().unwrap().len()
310 }
311}
312
313/// High-performance WebSocket broadcast hub with sharded client storage.
314///
315/// Supports 10k+ concurrent connections with bounded thread count.
316/// Uses NUM_SHARDS (16) shards to reduce lock contention.
317///
318/// Architecture:
319/// - Client connections are assigned to shards via round-robin (id % NUM_SHARDS).
320/// - Each shard has a dedicated broadcast worker thread that consumes from a channel.
321/// - Broadcast calls are non-blocking for the caller: they push to each shard's channel
322/// and return immediately.
323/// - Read-side threads use 64KB stacks (vs 2-8MB default) to keep memory bounded.
324/// - Total thread count: NUM_SHARDS broadcast workers + 1 per connected client (with
325/// minimal stack), plus the accept thread.
326pub struct WsHub {
327 shards: Vec<Arc<Shard>>,
328 next_id: Mutex<u64>,
329 /// Bounded-capacity senders for each shard's broadcast worker. When
330 /// a send would block because the queue is full, `broadcast_raw` drains
331 /// the oldest queued messages so new ones aren't lost to a stuck worker.
332 ///
333 /// Carries `Arc<str>` so a single broadcast event allocates the JSON
334 /// once and the 16 shard sends are cheap refcount bumps. Was a 16×
335 /// String clone hotspot under high write rates with thousands of
336 /// subscribers per shard.
337 broadcast_txs: Vec<mpsc::SyncSender<Arc<str>>>,
338 /// Matching receivers are held by each worker thread and also exposed
339 /// here so the "drop oldest" fallback can drain them on full. Keeping
340 /// the receiver handle alongside the sender is only safe because mpsc
341 /// lets multiple clones share a queue — here we only consume via the
342 /// worker, and the sender-side uses `try_send` + drain retry.
343 #[allow(dead_code)]
344 queue_depth: usize,
345 /// Per-client CRDT subscriptions. Reader threads register `(entity,
346 /// row_id)` pairs as the client mounts/unmounts useLoroDoc hooks;
347 /// the binary CRDT broadcast path uses `subscribers()` to filter the
348 /// fanout. Wrapped in Arc so the notifier (which holds `Arc<WsHub>`)
349 /// can read the subscriber set without taking an extra lock layer.
350 subscriptions: Arc<CrdtSubscriptions>,
351}
352
353impl WsHub {
354 pub fn new() -> Arc<Self> {
355 let mut shards = Vec::with_capacity(NUM_SHARDS);
356 let mut broadcast_txs = Vec::with_capacity(NUM_SHARDS);
357
358 for i in 0..NUM_SHARDS {
359 let shard = Arc::new(Shard::new());
360 // Bounded queue — if a broadcast worker stalls, `try_send` fails
361 // with Full and `broadcast_raw` drops the oldest to make room.
362 let (tx, rx) = mpsc::sync_channel::<Arc<str>>(BROADCAST_QUEUE_DEPTH);
363
364 let shard_clone = Arc::clone(&shard);
365 thread::Builder::new()
366 .name(format!("ws-broadcast-{i}"))
367 .spawn(move || {
368 while let Ok(msg) = rx.recv() {
369 shard_clone.broadcast(&msg);
370 }
371 })
372 .expect("Failed to spawn broadcast worker");
373
374 shards.push(shard);
375 broadcast_txs.push(tx);
376 }
377
378 Arc::new(Self {
379 shards,
380 next_id: Mutex::new(0),
381 broadcast_txs,
382 queue_depth: BROADCAST_QUEUE_DEPTH,
383 subscriptions: CrdtSubscriptions::new(),
384 })
385 }
386
387 /// Access the per-client CRDT subscription registry. The notifier
388 /// looks up subscribers via `subscriptions().subscribers(entity, row)`
389 /// and feeds them to `broadcast_binary_to`.
390 pub fn subscriptions(&self) -> &Arc<CrdtSubscriptions> {
391 &self.subscriptions
392 }
393
394 /// Broadcast a change event to ALL connected clients across all shards.
395 /// Non-blocking: pushes to each shard's channel and returns immediately.
396 ///
397 /// Serializes the event JSON exactly once into an `Arc<str>` and
398 /// shares it across the 16 shard senders. Each shard's worker
399 /// thread receives the same Arc and pays only a refcount bump.
400 pub fn broadcast(&self, event: &ChangeEvent) {
401 let json = match serde_json::to_string(event) {
402 Ok(j) => j,
403 Err(_) => return,
404 };
405 let shared: Arc<str> = Arc::from(json.into_boxed_str());
406 self.broadcast_shared(shared);
407 }
408
409 /// Broadcast a raw string message to all clients (used for presence updates).
410 pub fn broadcast_presence(&self, msg: &str) {
411 let shared: Arc<str> = Arc::from(msg.to_string().into_boxed_str());
412 self.broadcast_shared(shared);
413 }
414
415 /// Broadcast a binary frame to every connected client across all
416 /// shards. Used for CRDT updates (see `pylon_router::encode_crdt_frame`
417 /// for the wire shape). The bytes are wrapped in an `Arc` so each
418 /// shard's per-client fanout shares one allocation; the per-send
419 /// `to_vec()` cost is the tungstenite 0.24 contract.
420 ///
421 /// Synchronous fanout — iterates shards directly rather than going
422 /// through the per-shard mpsc workers. CRDT writes happen at most
423 /// once per logical mutation so the throughput shape is "occasional
424 /// burst" not "every keystroke", and direct fanout avoids growing a
425 /// second per-shard channel (Arc<[u8]> can't share the Arc<str>
426 /// channel without an enum, which costs more than the bypass).
427 pub fn broadcast_binary(&self, bytes: Vec<u8>) {
428 let shared: Arc<[u8]> = Arc::from(bytes.into_boxed_slice());
429 for shard in &self.shards {
430 shard.broadcast_binary(&shared);
431 }
432 }
433
434 /// Send a binary frame to a specific subset of client IDs only.
435 /// Used by the CRDT broadcast path to fan out only to clients
436 /// subscribed to the row that just changed (instead of every
437 /// connected client). Routes each id to its owning shard via
438 /// `id % NUM_SHARDS`.
439 ///
440 /// `client_ids` typically comes from `CrdtSubscriptions::subscribers`.
441 /// An empty list is a no-op — the row had no subscribers, so the
442 /// CRDT write is durable on the server but no client sees the
443 /// binary frame (they'll learn about the change via the JSON
444 /// change-event broadcast which always fires).
445 pub fn broadcast_binary_to(&self, client_ids: &[u64], bytes: Vec<u8>) {
446 if client_ids.is_empty() {
447 return;
448 }
449 let shared: Arc<[u8]> = Arc::from(bytes.into_boxed_slice());
450 // Group ids by shard so each shard's per-client lock is only
451 // grabbed once even if many subscribers landed in the same one.
452 let mut by_shard: Vec<Vec<u64>> = (0..NUM_SHARDS).map(|_| Vec::new()).collect();
453 for id in client_ids {
454 by_shard[(*id as usize) % NUM_SHARDS].push(*id);
455 }
456 for (idx, ids) in by_shard.iter().enumerate() {
457 if ids.is_empty() {
458 continue;
459 }
460 for dead_id in self.shards[idx].send_binary_to(ids, &shared) {
461 // Drop the dead client's subscription entries too —
462 // otherwise they leak until the reader thread's read
463 // timeout fires and runs unsubscribe_all on its own,
464 // and a future broadcast might re-attempt the dead id.
465 self.subscriptions.unsubscribe_all(dead_id);
466 }
467 }
468 }
469
470 /// Send a binary frame to a single client by id. Used by the
471 /// subscribe path: when a client subscribes to a row, the server
472 /// immediately ships the current snapshot so the new subscriber
473 /// has the up-to-date state without waiting for the next write.
474 pub fn send_binary_to_one(&self, client_id: u64, bytes: Vec<u8>) {
475 let shared: Arc<[u8]> = Arc::from(bytes.into_boxed_slice());
476 let shard_idx = (client_id as usize) % NUM_SHARDS;
477 for dead_id in self.shards[shard_idx].send_binary_to(&[client_id], &shared) {
478 self.subscriptions.unsubscribe_all(dead_id);
479 }
480 }
481
482 /// Internal: fan out a single shared message to every shard worker.
483 ///
484 /// Uses `try_send`; on full we log once (per call) and drop the message
485 /// for that shard. Previously the channel was unbounded, so a stuck
486 /// worker thread would grow memory until OOM. The new bounded queue
487 /// means a slow/stuck subscriber at worst loses broadcast events —
488 /// correctness for critical data still comes through the change-log
489 /// cursor on a reconnect.
490 fn broadcast_shared(&self, msg: Arc<str>) {
491 for tx in &self.broadcast_txs {
492 match tx.try_send(Arc::clone(&msg)) {
493 Ok(()) => {}
494 Err(mpsc::TrySendError::Full(_)) => {
495 tracing::warn!("[ws] broadcast queue full — dropping event for one shard");
496 }
497 Err(mpsc::TrySendError::Disconnected(_)) => {
498 // Worker exited (shutdown). Silent.
499 }
500 }
501 }
502 }
503
504 /// Assign a client to a shard via round-robin and register it.
505 /// Returns `(id, socket_handle)` — the caller keeps the handle and uses
506 /// it for reads; the shard also keeps an Arc clone for broadcasts.
507 fn add_client(&self, ws: WebSocket<TcpStream>) -> (u64, ClientSocket) {
508 let mut next_id = self.next_id.lock().unwrap();
509 let id = *next_id;
510 *next_id += 1;
511 let shard_idx = (id as usize) % NUM_SHARDS;
512 let handle = self.shards[shard_idx].add(id, ws);
513 (id, handle)
514 }
515
516 fn remove_client(&self, id: u64) {
517 let shard_idx = (id as usize) % NUM_SHARDS;
518 self.shards[shard_idx].remove(id);
519 }
520
521 /// Total number of connected clients across all shards.
522 pub fn client_count(&self) -> usize {
523 self.shards.iter().map(|s| s.count()).sum()
524 }
525}
526
527/// Snapshot fetcher: given the caller's auth context + `(entity,
528/// row_id)`, return the encoded binary CRDT frame for the row's
529/// current state, or `None` if either the caller can't read the row
530/// (read policy denies) or the row has no snapshot (uninitialized
531/// CRDT or non-CRDT entity).
532///
533/// Auth context is passed in (rather than checked at the WS layer)
534/// because the policy engine + DataStore handles live in the runtime
535/// crate. Without this check an authenticated client could subscribe
536/// to any `(entity, row_id)` and receive every binary CRDT frame
537/// even for rows their query policy would reject — a silent read-
538/// policy bypass.
539///
540/// Wrapped in an Arc<dyn Fn> so the runtime can build it once, capturing
541/// the LoroStore + PolicyEngine handles, and hand the same closure to
542/// every accepted connection.
543pub type SnapshotFetcher =
544 Arc<dyn Fn(&pylon_auth::AuthContext, &str, &str) -> Option<Vec<u8>> + Send + Sync>;
545
546/// Start the WebSocket server on the given port.
547///
548/// The accept loop runs on the calling thread (blocking). Each accepted
549/// connection spawns a lightweight reader thread with a 64KB stack.
550/// Broadcast writes are handled by the shard worker threads, not by
551/// per-client threads.
552///
553/// The session store is required: every connection must present a valid
554/// bearer token (Authorization header or `bearer.<token>` subprotocol —
555/// browsers can't set WS headers directly). Previously the notifier hub
556/// accepted any connection and streamed every ChangeEvent/presence event
557/// to it, which was a silent read-policy bypass.
558///
559/// `snapshot_fetcher` is optional — when present, the reader will ship
560/// the current CRDT snapshot to the subscribing client immediately on
561/// `crdt-subscribe`, so the new tab sees the latest converged state
562/// without waiting for the next write. When absent, subscribe is still
563/// recorded but the catch-up frame is skipped.
564pub fn start_ws_server(
565 hub: Arc<WsHub>,
566 sessions: Arc<SessionStore>,
567 port: u16,
568 snapshot_fetcher: Option<SnapshotFetcher>,
569) {
570 let addr = format!("0.0.0.0:{port}");
571 let listener = match TcpListener::bind(&addr) {
572 Ok(l) => l,
573 Err(e) => {
574 tracing::warn!("[ws] Failed to bind on {addr}: {e}");
575 return;
576 }
577 };
578
579 tracing::warn!(
580 "[ws] WebSocket server listening on ws://localhost:{port} (sharded, {NUM_SHARDS} shards)"
581 );
582
583 let ip_counter = Arc::new(IpConnCounter::default());
584
585 for stream in listener.incoming() {
586 let stream = match stream {
587 Ok(s) => s,
588 Err(_) => continue,
589 };
590
591 // Per-IP connection cap: reject BEFORE the handshake so a cheap
592 // connect storm doesn't force us through tungstenite's HTTP parse
593 // and the session-resolve round trip. The guard is dropped when
594 // the reader thread exits (or fails to start), freeing the slot.
595 let ip = match stream.peer_addr() {
596 Ok(addr) => addr.ip(),
597 Err(_) => continue,
598 };
599 let guard = match ip_counter.acquire(ip) {
600 Some(g) => g,
601 None => {
602 // Ignore: let the client re-try after an existing connection
603 // closes. Previously an IP could open unbounded connections
604 // and each one spawned a thread + held a per-client mutex.
605 continue;
606 }
607 };
608
609 let hub = Arc::clone(&hub);
610 let sessions = Arc::clone(&sessions);
611 let fetcher = snapshot_fetcher.clone();
612 // Spawn a reader thread per client with a small stack.
613 // 64KB stack * 10k connections = ~640MB, vs 2-8MB default * 10k = 20-80GB.
614 let spawn_result = thread::Builder::new()
615 .name("ws-client".into())
616 .stack_size(64 * 1024)
617 .spawn(move || {
618 // Holding `guard` for the life of the connection thread is
619 // what makes the decrement-on-disconnect contract work. Not
620 // `let _ = guard;` — that drops immediately.
621 let _conn_slot = guard;
622 handle_ws_connection(hub, sessions, stream, fetcher);
623 });
624 if spawn_result.is_err() {
625 // Thread creation failed — guard is already dropped here, slot
626 // returned. We deliberately don't call `continue` before the
627 // spawn: we've paid the acquire cost and want to avoid leaking
628 // a slot under transient thread-limit pressure.
629 }
630 }
631}
632
633/// Handle a single WebSocket client connection.
634///
635/// Sets a read timeout to prevent zombie threads on dead connections.
636/// Handles ping/pong for keepalive, presence/topic message relay,
637/// and clean disconnect with presence broadcast.
638fn handle_ws_connection(
639 hub: Arc<WsHub>,
640 sessions: Arc<SessionStore>,
641 stream: TcpStream,
642 snapshot_fetcher: Option<SnapshotFetcher>,
643) {
644 // Short read timeout bounds how long the PER-CLIENT mutex is held
645 // while this thread is blocked in socket.read(). Each client now has
646 // its own mutex (not a shard-wide one), so a quiet client only stalls
647 // the broadcaster when it's broadcasting to THAT specific client —
648 // other clients in the same shard proceed without contention.
649 stream.set_read_timeout(Some(WS_READ_TIMEOUT)).ok();
650 // Also cap write time. A stuck kernel send (slow client, full send
651 // buffer, dropped packets) would otherwise stall the shard's
652 // broadcast worker holding this client's mutex — backpressure
653 // becomes head-of-line blocking for everyone. Capped at 5s; slow
654 // clients get disconnected rather than stalling the hub.
655 stream.set_write_timeout(Some(WS_READ_TIMEOUT)).ok();
656
657 // Extract the bearer token from the handshake, preferring the
658 // Authorization header (native clients) and falling back to the
659 // `bearer.<token>` WebSocket subprotocol (browsers). We only learn
660 // whether the token is valid AFTER accept_hdr completes, since the
661 // header callback must return synchronously with a Response.
662 let token_slot: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
663 let slot_for_cb = Arc::clone(&token_slot);
664 // Cap WebSocket frame size to bound memory per connection. The
665 // tungstenite default (64 MiB) is too generous — a single client
666 // can shovel huge frames and starve other connections. The cap
667 // applies BIDIRECTIONALLY (server-sent CRDT snapshots are
668 // checked against it too), so the default must accommodate the
669 // largest legitimate snapshot — 16 MiB covers Loro docs with
670 // long histories. Operators tune via PYLON_WS_MAX_FRAME (bytes)
671 // when they have unusually large or unusually small docs.
672 let max_frame: usize = std::env::var("PYLON_WS_MAX_FRAME")
673 .ok()
674 .and_then(|v| v.parse().ok())
675 .unwrap_or(16 * 1024 * 1024);
676 let ws_config = WebSocketConfig {
677 max_message_size: Some(max_frame),
678 max_frame_size: Some(max_frame),
679 ..Default::default()
680 };
681 let ws = match accept_hdr_with_config(
682 stream,
683 move |req: &Request, mut resp: Response| -> Result<Response, ErrorResponse> {
684 let mut chosen_protocol: Option<String> = None;
685 let mut auth: Option<String> = None;
686 for (name, value) in req.headers() {
687 let lower = name.as_str().to_ascii_lowercase();
688 if lower == "authorization" {
689 if let Ok(v) = value.to_str() {
690 if let Some(tok) = v.strip_prefix("Bearer ") {
691 auth = Some(tok.to_string());
692 }
693 }
694 } else if lower == "sec-websocket-protocol" {
695 if let Ok(v) = value.to_str() {
696 for proto in v.split(',').map(str::trim) {
697 if let Some(encoded) = proto.strip_prefix("bearer.") {
698 if let Some(decoded) = percent_decode_token(encoded) {
699 auth = auth.or(Some(decoded));
700 chosen_protocol = Some(proto.to_string());
701 break;
702 }
703 }
704 }
705 }
706 }
707 }
708 // RFC 6455 §11.3.4 — echo the chosen subprotocol in the response or
709 // browsers will refuse the connection.
710 if let Some(chosen) = chosen_protocol {
711 if let Ok(hv) = tungstenite::http::HeaderValue::from_str(&chosen) {
712 resp.headers_mut().insert("Sec-WebSocket-Protocol", hv);
713 }
714 }
715 *slot_for_cb.lock().unwrap() = auth;
716 Ok(resp)
717 },
718 Some(ws_config),
719 ) {
720 Ok(ws) => ws,
721 Err(_) => return,
722 };
723
724 // Reject unauthenticated or invalid-token handshakes AFTER accept —
725 // tungstenite's handshake callback can't easily return a 401 without
726 // a custom error response, and we already have the socket open for
727 // a clean close frame.
728 let token = token_slot.lock().unwrap().clone();
729 let auth_ctx = sessions.resolve(token.as_deref());
730 if auth_ctx.user_id.is_none() && !auth_ctx.is_admin {
731 let mut ws = ws;
732 let _ = ws.close(Some(tungstenite::protocol::CloseFrame {
733 code: tungstenite::protocol::frame::coding::CloseCode::Policy,
734 reason: "unauthorized: bearer token required".into(),
735 }));
736 return;
737 }
738
739 let (client_id, socket_handle) = hub.add_client(ws);
740
741 loop {
742 // Lock this client's socket mutex only for the duration of the
743 // read. With a 5s read timeout, broadcasters waiting to send to
744 // THIS client wait at most 5s. Other clients are never blocked
745 // by this lock — they have their own.
746 let msg = {
747 let mut guard = match socket_handle.lock() {
748 Ok(g) => g,
749 Err(poisoned) => poisoned.into_inner(),
750 };
751 guard.read()
752 };
753
754 match msg {
755 Ok(Message::Text(text)) => {
756 // Parse once and dispatch on the type field instead of
757 // matching prefix bytes — that approach silently dropped
758 // valid JSON with whitespace, key reordering, or any
759 // other formatting variation. Non-object / no-`type`
760 // messages are ignored.
761 let parsed: serde_json::Value = match serde_json::from_str(&text) {
762 Ok(v) => v,
763 Err(_) => continue,
764 };
765 let kind = parsed.get("type").and_then(|v| v.as_str()).unwrap_or("");
766 match kind {
767 "presence" | "topic" => {
768 // Stamp the authenticated sender server-side,
769 // overriding any client-provided `from`. Without
770 // this, any client could spoof presence/topic
771 // events as another user — every connected
772 // client would see a forged "alice typed…"
773 // message attributed to alice.
774 let mut stamped = parsed.clone();
775 if let Some(obj) = stamped.as_object_mut() {
776 let from = auth_ctx
777 .user_id
778 .clone()
779 .unwrap_or_else(|| "admin".to_string());
780 obj.insert("from".into(), serde_json::Value::String(from));
781 }
782 hub.broadcast_presence(&stamped.to_string());
783 }
784 "crdt-subscribe" | "crdt-unsubscribe" => handle_crdt_control(
785 &hub,
786 client_id,
787 &auth_ctx,
788 kind,
789 &parsed,
790 snapshot_fetcher.as_ref(),
791 ),
792 _ => {}
793 }
794 }
795 Ok(Message::Ping(data)) => {
796 // Respond with pong to keep the connection alive.
797 if let Ok(mut guard) = socket_handle.lock() {
798 let _ = guard.send(Message::Pong(data));
799 }
800 }
801 Ok(Message::Close(_)) => {
802 // Drop every CRDT subscription this client held BEFORE
803 // remove_client so the broadcast path can never look up
804 // a stale client_id between the two ops.
805 hub.subscriptions.unsubscribe_all(client_id);
806 hub.remove_client(client_id);
807 let disconnect = serde_json::json!({
808 "type": "presence",
809 "event": "disconnect",
810 "clientId": client_id,
811 });
812 hub.broadcast_presence(&disconnect.to_string());
813 break;
814 }
815 Err(tungstenite::Error::Io(io_err))
816 if io_err.kind() == std::io::ErrorKind::WouldBlock
817 || io_err.kind() == std::io::ErrorKind::TimedOut =>
818 {
819 // Read timed out — this is EXPECTED with the short
820 // timeout. In theory the mutex is released between
821 // iterations, but `std::sync::Mutex` is not fair: a tight
822 // loop of lock→read→unlock→lock starves the broadcaster
823 // that's been waiting on the same mutex. Explicitly sleep
824 // for a tick so the broadcaster gets scheduled. 1ms is
825 // long enough to hand off, short enough that client→server
826 // latency stays sub-5ms.
827 std::thread::sleep(std::time::Duration::from_millis(1));
828 continue;
829 }
830 Err(_) => {
831 hub.subscriptions.unsubscribe_all(client_id);
832 hub.remove_client(client_id);
833 let disconnect = serde_json::json!({
834 "type": "presence",
835 "event": "disconnect",
836 "clientId": client_id,
837 });
838 hub.broadcast_presence(&disconnect.to_string());
839 break;
840 }
841 _ => {}
842 }
843 }
844}
845
846/// Apply a parsed `crdt-subscribe` / `crdt-unsubscribe` control
847/// message. Both messages have the shape:
848///
849/// { "type": "crdt-subscribe", "entity": "<E>", "rowId": "<id>" }
850/// { "type": "crdt-unsubscribe", "entity": "<E>", "rowId": "<id>" }
851///
852/// On subscribe the snapshot fetcher checks read policy for the
853/// caller's auth context — if the caller can't read the row we
854/// register no subscription and ship nothing back, so a malicious
855/// client can't peek at a row their query policy would block by
856/// just subscribing to its CRDT stream.
857///
858/// Malformed messages are silently dropped — there's no client-visible
859/// ACK protocol, so a typo in the payload would just look like a
860/// row that never receives updates. Logging would invite a noise
861/// channel for misbehaving clients.
862fn handle_crdt_control(
863 hub: &Arc<WsHub>,
864 client_id: u64,
865 auth_ctx: &pylon_auth::AuthContext,
866 kind: &str,
867 parsed: &serde_json::Value,
868 snapshot_fetcher: Option<&SnapshotFetcher>,
869) {
870 let entity = match parsed.get("entity").and_then(|v| v.as_str()) {
871 Some(e) if !e.is_empty() => e,
872 _ => return,
873 };
874 let row_id = match parsed
875 .get("rowId")
876 .or_else(|| parsed.get("row_id"))
877 .and_then(|v| v.as_str())
878 {
879 Some(r) if !r.is_empty() => r,
880 _ => return,
881 };
882
883 match kind {
884 "crdt-subscribe" => {
885 // Authz check happens INSIDE the fetcher (it has access to
886 // the policy engine + DataStore). When a fetcher is wired
887 // and returns None, the caller is either denied or the row
888 // doesn't exist — in both cases we refuse to register the
889 // subscription so a denied caller can't silently hold an
890 // open slot waiting for future writes.
891 //
892 // When no fetcher is wired (test harnesses, future
893 // workers backend without DataStore access) we trust the
894 // caller and register without the auth gate. Production
895 // server.rs always wires one, so this loophole is
896 // unreachable in deployed configurations.
897 let snapshot = snapshot_fetcher.and_then(|f| f(auth_ctx, entity, row_id));
898 let allow_subscribe = snapshot_fetcher.is_none() || snapshot.is_some();
899 if allow_subscribe {
900 hub.subscriptions.subscribe(client_id, entity, row_id);
901 if let Some(bytes) = snapshot {
902 hub.send_binary_to_one(client_id, bytes);
903 }
904 }
905 }
906 "crdt-unsubscribe" => {
907 hub.subscriptions.unsubscribe(client_id, entity, row_id);
908 }
909 _ => {}
910 }
911}
912
/// Strictly percent-decode the payload of a `bearer.<token>` subprotocol.
///
/// Decoding rules:
///
/// - `%XX` escapes (either hex case) become the byte `0xXX`.
/// - `+` becomes a space, following the form-urlencoded convention.
/// - Every other byte is copied through unchanged.
///
/// Returns `None` if an escape is truncated or not hex, or if the decoded
/// bytes are not valid UTF-8. Rejecting garbage here keeps a mangled token
/// away from the session store. There it would simply fail to resolve and
/// look like an ordinary unauthenticated attempt.
///
/// NOTE(review): the `+` → space mapping corrupts raw tokens that contain a
/// literal `+` (e.g. standard base64). Such clients must send `%2B`.
/// Confirm this matches the client-side encoder.
fn percent_decode_token(s: &str) -> Option<String> {
    // Value of a single ASCII hex digit; `None` for any other byte.
    fn hex_digit(byte: u8) -> Option<u8> {
        (byte as char).to_digit(16).map(|digit| digit as u8)
    }

    let mut bytes = s.bytes();
    let mut decoded = Vec::with_capacity(s.len());
    while let Some(byte) = bytes.next() {
        let next = match byte {
            // Both digits must be present and hex; a missing byte or a
            // non-hex byte rejects the whole token.
            b'%' => {
                let high = hex_digit(bytes.next()?)?;
                let low = hex_digit(bytes.next()?)?;
                (high << 4) | low
            }
            b'+' => b' ',
            other => other,
        };
        decoded.push(next);
    }
    String::from_utf8(decoded).ok()
}
944
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn shard_count_starts_at_zero() {
        let shard = Shard::new();
        assert_eq!(shard.count(), 0);
    }

    #[test]
    fn hub_starts_with_zero_clients() {
        let hub = WsHub::new();
        assert_eq!(hub.client_count(), 0);
    }

    #[test]
    fn broadcast_to_empty_hub_doesnt_panic() {
        let hub = WsHub::new();
        let event = ChangeEvent {
            seq: 1,
            entity: "Test".into(),
            row_id: "1".into(),
            kind: pylon_sync::ChangeKind::Insert,
            data: None,
            timestamp: String::new(),
        };
        hub.broadcast(&event);
        hub.broadcast_presence("test");
    }

    #[test]
    fn num_shards_is_power_of_two() {
        // Modulo alone spreads sequential ids evenly for any shard count.
        // The power-of-two requirement is presumably so shard selection can
        // reduce to a bit mask — confirm against the shard-selection code.
        assert!(
            NUM_SHARDS.is_power_of_two(),
            "NUM_SHARDS ({NUM_SHARDS}) must be a power of two for even distribution"
        );
    }

    #[test]
    fn crdt_subscriptions_subscribe_dedups() {
        let subs = CrdtSubscriptions::default();
        subs.subscribe(1, "Channel", "abc");
        subs.subscribe(1, "Channel", "abc");
        assert_eq!(subs.subscribers("Channel", "abc"), vec![1]);
        assert_eq!(subs.total_subscriptions(), 1);
    }

    #[test]
    fn crdt_subscriptions_returns_all_subscribers() {
        let subs = CrdtSubscriptions::default();
        subs.subscribe(1, "Channel", "abc");
        subs.subscribe(2, "Channel", "abc");
        subs.subscribe(3, "Channel", "abc");
        let mut ids = subs.subscribers("Channel", "abc");
        ids.sort();
        assert_eq!(ids, vec![1, 2, 3]);
    }

    #[test]
    fn crdt_subscriptions_unsubscribe_cleans_empty_rows() {
        let subs = CrdtSubscriptions::default();
        subs.subscribe(1, "Channel", "abc");
        subs.unsubscribe(1, "Channel", "abc");
        assert!(subs.subscribers("Channel", "abc").is_empty());
        // total should drop the empty by_row entry, not leave a 0-set
        // around forever.
        assert_eq!(subs.total_subscriptions(), 0);
    }

    #[test]
    fn crdt_subscriptions_unsubscribe_all_drops_every_row() {
        let subs = CrdtSubscriptions::default();
        subs.subscribe(1, "Channel", "a");
        subs.subscribe(1, "Channel", "b");
        subs.subscribe(1, "Message", "m1");
        subs.subscribe(2, "Channel", "a"); // someone else, must survive
        subs.unsubscribe_all(1);
        assert!(subs.subscribers("Channel", "b").is_empty());
        assert!(subs.subscribers("Message", "m1").is_empty());
        // Client 2 is still there.
        assert_eq!(subs.subscribers("Channel", "a"), vec![2]);
        // Only client 2's (Channel, a) subscription may remain. Client 1's
        // rows must not linger in the registry as empty sets.
        assert_eq!(subs.total_subscriptions(), 1);
    }

    #[test]
    fn crdt_subscriptions_unsubscribe_unknown_client_is_noop() {
        let subs = CrdtSubscriptions::default();
        subs.unsubscribe(99, "Channel", "abc");
        subs.unsubscribe_all(99);
        assert_eq!(subs.total_subscriptions(), 0);
    }

    #[test]
    fn crdt_subscriptions_concurrent_subscribe_and_unsubscribe() {
        // Hammer subscribe + unsubscribe from many threads to verify
        // the single-mutex design keeps by_row and by_client in sync.
        // Previous two-mutex version could leave the maps divergent
        // under interleaving.
        let subs = Arc::new(CrdtSubscriptions::default());
        let mut handles = Vec::new();
        for client_id in 0..16u64 {
            let subs = Arc::clone(&subs);
            handles.push(std::thread::spawn(move || {
                for i in 0..200 {
                    let row = format!("row-{i}");
                    subs.subscribe(client_id, "Channel", &row);
                    subs.unsubscribe(client_id, "Channel", &row);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        // Every subscribe paired with an unsubscribe — registry must be
        // fully drained.
        assert_eq!(subs.total_subscriptions(), 0);
    }

    #[test]
    fn crdt_subscriptions_unsubscribe_all_after_concurrent_subscribes() {
        let subs = Arc::new(CrdtSubscriptions::default());
        let mut handles = Vec::new();
        for client_id in 0..8u64 {
            let subs = Arc::clone(&subs);
            handles.push(std::thread::spawn(move || {
                for i in 0..100 {
                    let row = format!("row-{i}");
                    subs.subscribe(client_id, "Channel", &row);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        // Now wipe each client and confirm no orphan rows remain.
        for client_id in 0..8u64 {
            subs.unsubscribe_all(client_id);
        }
        assert_eq!(subs.total_subscriptions(), 0);
    }

    #[test]
    fn shard_assignment_distributes_evenly() {
        // Verify that sequential IDs spread across all shards.
        //
        // NOTE(review): this recomputes `id % NUM_SHARDS` locally instead of
        // calling the hub's real shard-selection code. It only checks
        // arithmetic, not production behavior. Consider routing it through
        // the actual helper.
        let mut counts = vec![0usize; NUM_SHARDS];
        for id in 0..(NUM_SHARDS as u64 * 100) {
            counts[(id as usize) % NUM_SHARDS] += 1;
        }
        // Every shard should get exactly 100 clients.
        for (i, count) in counts.iter().enumerate() {
            assert_eq!(*count, 100, "Shard {i} got {count} clients, expected 100");
        }
    }

    #[test]
    fn percent_decode_token_decodes_escapes_and_plus() {
        assert_eq!(percent_decode_token("abc").as_deref(), Some("abc"));
        assert_eq!(percent_decode_token("").as_deref(), Some(""));
        assert_eq!(percent_decode_token("a%20b").as_deref(), Some("a b"));
        // `+` is treated form-style as a space; a literal plus needs `%2B`.
        assert_eq!(percent_decode_token("a+b").as_deref(), Some("a b"));
        // Hex digits are case-insensitive.
        assert_eq!(percent_decode_token("%2B%2f").as_deref(), Some("+/"));
        // Multi-byte UTF-8, both escaped and raw.
        assert_eq!(percent_decode_token("%C3%A9").as_deref(), Some("é"));
        assert_eq!(percent_decode_token("é").as_deref(), Some("é"));
    }

    #[test]
    fn percent_decode_token_rejects_malformed_input() {
        // Truncated escapes.
        assert_eq!(percent_decode_token("%"), None);
        assert_eq!(percent_decode_token("%4"), None);
        assert_eq!(percent_decode_token("abc%"), None);
        // Non-hex digits.
        assert_eq!(percent_decode_token("%zz"), None);
        assert_eq!(percent_decode_token("%4g"), None);
        // Escapes that decode to invalid UTF-8.
        assert_eq!(percent_decode_token("%FF"), None);
    }
}