pylon_runtime/ws.rs
1use std::collections::{HashMap, HashSet};
2use std::net::{TcpListener, TcpStream};
3use std::sync::mpsc;
4use std::sync::{Arc, Mutex};
5use std::thread;
6use std::time::Duration;
7
8use pylon_auth::SessionStore;
9use pylon_sync::ChangeEvent;
10use tungstenite::handshake::server::{ErrorResponse, Request, Response};
11use tungstenite::{accept_hdr, Message, WebSocket};
12
13use crate::ip_limit::IpConnCounter;
14
15// ---------------------------------------------------------------------------
16// CRDT subscription manager
17//
18// Per-client subscriptions to (entity, row_id) pairs. Lets the binary CRDT
19// broadcast filter to only the clients that asked, instead of fanning out
20// every CRDT write to every connected WS client.
21//
22// Two reverse maps so both hot paths are O(subscribers per row) and
23// O(rows per client): the broadcast looks up subscribers by row, the
24// disconnect cleanup walks rows by client.
25//
26// Subscriptions are explicit and ephemeral — a client subscribes when
27// useLoroDoc(entity, id) mounts, unsubscribes on unmount or disconnect.
28// Server doesn't persist subscriptions across reconnects; the client
29// re-sends them.
30// ---------------------------------------------------------------------------
31
/// Both directions of the client ↔ row subscription relation. The two maps
/// are always mutated together under the one mutex owned by
/// [`CrdtSubscriptions`].
#[derive(Default)]
struct SubsState {
    /// (entity, row_id) → set of client_ids subscribed to that row.
    by_row: HashMap<(String, String), HashSet<u64>>,
    /// client_id → set of (entity, row_id) it subscribes to.
    /// Inverted so disconnect cleanup walks only that client's rows
    /// instead of scanning every entry in `by_row`.
    by_client: HashMap<u64, HashSet<(String, String)>>,
}

impl SubsState {
    /// Remove `client_id` from the subscriber set of `key`, dropping the
    /// set itself once it becomes empty so `by_row` never accumulates
    /// orphan entries.
    fn detach_from_row(&mut self, key: &(String, String), client_id: u64) {
        if let Some(subscribers) = self.by_row.get_mut(key) {
            subscribers.remove(&client_id);
            if subscribers.is_empty() {
                self.by_row.remove(key);
            }
        }
    }
}

/// Registry of per-client CRDT row subscriptions.
#[derive(Default)]
pub struct CrdtSubscriptions {
    /// Single mutex covers both reverse maps so any pair of operations
    /// (subscribe + unsubscribe across threads, broadcast + disconnect
    /// cleanup) sees a consistent view. Two separate mutexes would let
    /// `subscribe` land in `by_row` while a concurrent `unsubscribe_all`
    /// snapshots `by_client` mid-update, leaving the maps divergent.
    state: Mutex<SubsState>,
}

impl CrdtSubscriptions {
    /// Create an empty registry, already wrapped in an `Arc` for sharing.
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }

    /// Acquire the state lock, recovering the guard if a previous holder
    /// panicked. Every mutation in this type is a short sequence of map
    /// inserts/removes, so continuing with the recovered state is
    /// preferable to propagating the panic into every reader thread and
    /// the broadcast path.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, SubsState> {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Register a client's interest in a row. Idempotent — re-subscribing
    /// the same client to the same row is a no-op (HashSet semantics).
    pub fn subscribe(&self, client_id: u64, entity: &str, row_id: &str) {
        let key = (entity.to_owned(), row_id.to_owned());
        let mut state = self.lock_state();
        state
            .by_row
            .entry(key.clone())
            .or_default()
            .insert(client_id);
        state.by_client.entry(client_id).or_default().insert(key);
    }

    /// Drop one subscription. Empty sets are removed from both maps so
    /// long-running connections that subscribe and unsubscribe to many
    /// rows over their lifetime don't accumulate orphan entries.
    /// Unsubscribing from a row the client never subscribed to is a no-op.
    pub fn unsubscribe(&self, client_id: u64, entity: &str, row_id: &str) {
        let key = (entity.to_owned(), row_id.to_owned());
        let mut state = self.lock_state();
        state.detach_from_row(&key, client_id);
        if let Some(rows) = state.by_client.get_mut(&client_id) {
            rows.remove(&key);
            if rows.is_empty() {
                state.by_client.remove(&client_id);
            }
        }
    }

    /// Drop every subscription for a client (called on WS disconnect or
    /// when a broadcast send fails for that client). The whole removal
    /// happens under one lock acquisition, so a concurrent `subscribers`
    /// snapshot sees the client either fully present or fully gone.
    pub fn unsubscribe_all(&self, client_id: u64) {
        let mut state = self.lock_state();
        let rows = state.by_client.remove(&client_id).unwrap_or_default();
        for key in &rows {
            state.detach_from_row(key, client_id);
        }
    }

    /// Snapshot the subscriber set for a row. Returns an owned `Vec`
    /// (in unspecified order) rather than a guard so the caller does not
    /// hold the mutex during its per-client send loop. Empty when the row
    /// has no subscribers.
    pub fn subscribers(&self, entity: &str, row_id: &str) -> Vec<u64> {
        let key = (entity.to_owned(), row_id.to_owned());
        self.lock_state()
            .by_row
            .get(&key)
            .map(|subscribers| subscribers.iter().copied().collect())
            .unwrap_or_default()
    }

    /// Diagnostic: total number of (client, row) pairs.
    pub fn total_subscriptions(&self) -> usize {
        self.lock_state().by_row.values().map(HashSet::len).sum()
    }
}
143
/// Number of shards the connected WebSocket clients are spread across.
///
/// Client ids are handed out sequentially and routed with
/// `id % NUM_SHARDS`, and one broadcast worker thread is spawned per shard.
/// The original note asks for a power of two.
/// NOTE(review): sequential ids taken modulo any shard count distribute
/// evenly, so nothing visible in this file needs a power of two — confirm
/// whether another module relies on it before changing the value.
const NUM_SHARDS: usize = 16;
147
/// Maximum number of outbound messages queued per shard. Once a shard's
/// broadcast worker falls this many messages behind, the hub's `try_send`
/// fails with `Full` and the NEW message is dropped for that shard (the
/// already-queued messages are kept). That means slow subscribers can
/// miss messages — but the alternative (unbounded queue) was OOM when a
/// single stuck client blocked its shard worker.
///
/// Callers that need exact delivery should layer their own retry on top
/// (the change-log cursor protocol already does this for sync).
const BROADCAST_QUEUE_DEPTH: usize = 1024;
157
/// Read timeout applied to every client socket. Kept low because the reader
/// thread holds the client's socket mutex while it is blocked in `read()`;
/// a short timeout releases that mutex frequently, letting the broadcaster
/// get its turn even if the client never sends anything. Previously this
/// was 120s, which meant one quiet client could wedge the shard's writer
/// for up to two minutes.
const WS_READ_TIMEOUT: Duration = Duration::from_millis(200);
163
/// Shared handle to one connected client's WebSocket.
///
/// The socket lives behind its OWN `Mutex`, not a shard-wide one, so the
/// reader thread's blocking `socket.read()` doesn't hold a lock that covers
/// every client in the same shard. The broadcaster snapshots the shard's
/// client map (outer lock held only for that enumeration), then takes each
/// client's individual mutex to do the `socket.send`. Contention is
/// therefore per-client instead of per-shard. The `Arc` lets the shard map
/// and the client's reader thread each hold a reference to the same socket.
type ClientSocket = Arc<Mutex<WebSocket<TcpStream>>>;
172
173/// A single shard holding a subset of WebSocket clients.
174///
175/// The outer `Mutex<HashMap>` is held only for insert/remove and while
176/// enumerating client handles — never across I/O.
177struct Shard {
178 clients: Mutex<HashMap<u64, ClientSocket>>,
179}
180
181impl Shard {
182 fn new() -> Self {
183 Self {
184 clients: Mutex::new(HashMap::new()),
185 }
186 }
187
188 fn add(&self, id: u64, ws: WebSocket<TcpStream>) -> ClientSocket {
189 let handle = Arc::new(Mutex::new(ws));
190 self.clients.lock().unwrap().insert(id, Arc::clone(&handle));
191 handle
192 }
193
194 fn remove(&self, id: u64) {
195 self.clients.lock().unwrap().remove(&id);
196 }
197
198 /// Send a message to all clients in this shard.
199 ///
200 /// Snapshot the client handles under the shard lock, drop the shard
201 /// lock, then contend only with per-client mutexes to do the writes.
202 /// This is what lets a reader thread hold its client's mutex for a
203 /// socket.read() without stalling broadcasts for the whole shard.
204 ///
205 /// `msg` is `Arc<str>` rather than `&str` so the caller can serialize
206 /// the JSON exactly once and share the same allocation across all
207 /// 16 shards. Per-client `Message::Text` still allocates an owned
208 /// String (tungstenite 0.24 requires it), but the broadcast no
209 /// longer pays N copies of the JSON across shard channels.
210 fn broadcast(&self, msg: &Arc<str>) {
211 let handles: Vec<(u64, ClientSocket)> = {
212 let clients = self.clients.lock().unwrap();
213 clients.iter().map(|(id, h)| (*id, Arc::clone(h))).collect()
214 };
215 let mut dead: Vec<u64> = Vec::new();
216 for (id, handle) in handles {
217 // `try_lock` would skip clients whose reader is currently
218 // blocked in read(); we prefer `lock()` here so the occasional
219 // broadcaster wait (bounded by the 200ms read timeout) doesn't
220 // drop the message for that client.
221 let mut guard = match handle.lock() {
222 Ok(g) => g,
223 Err(poisoned) => poisoned.into_inner(),
224 };
225 // Owned String per send is the tungstenite 0.24 contract.
226 // The clone here copies the string contents; sharing the
227 // raw bytes via Utf8Bytes would be the next-level
228 // optimization but requires a tungstenite version bump.
229 if guard.send(Message::Text((**msg).to_string())).is_err() {
230 dead.push(id);
231 }
232 }
233 if !dead.is_empty() {
234 let mut clients = self.clients.lock().unwrap();
235 for id in &dead {
236 clients.remove(id);
237 }
238 }
239 }
240
241 /// Send a binary frame to a SPECIFIC subset of this shard's clients.
242 /// Used by the per-client subscription path — `WsHub::broadcast_binary_to`
243 /// computes which ids each shard owns and calls this with just those.
244 ///
245 /// Same per-client lock pattern as `broadcast` / `broadcast_binary`,
246 /// just filtered up front instead of iterating the whole shard.
247 ///
248 /// Returns the list of client ids whose send failed so the caller
249 /// can also clear those ids from the CRDT subscription registry —
250 /// without that step a dead client's subscription entries linger
251 /// until the reader thread notices the EOF and runs unsubscribe_all,
252 /// which can take up to one read-timeout (200ms) longer than the
253 /// send-side death detection.
254 fn send_binary_to(&self, ids: &[u64], msg: &Arc<[u8]>) -> Vec<u64> {
255 let handles: Vec<(u64, ClientSocket)> = {
256 let clients = self.clients.lock().unwrap();
257 ids.iter()
258 .filter_map(|id| clients.get(id).map(|h| (*id, Arc::clone(h))))
259 .collect()
260 };
261 let mut dead: Vec<u64> = Vec::new();
262 for (id, handle) in handles {
263 let mut guard = match handle.lock() {
264 Ok(g) => g,
265 Err(poisoned) => poisoned.into_inner(),
266 };
267 if guard.send(Message::Binary(msg.to_vec())).is_err() {
268 dead.push(id);
269 }
270 }
271 if !dead.is_empty() {
272 let mut clients = self.clients.lock().unwrap();
273 for id in &dead {
274 clients.remove(id);
275 }
276 }
277 dead
278 }
279
280 /// Binary fanout for CRDT updates. Same per-client lock pattern as
281 /// `broadcast` above; the only difference is `Message::Binary` and
282 /// the payload is `Arc<[u8]>` so a single Loro snapshot allocates
283 /// once and the per-client send pays a refcount bump + the
284 /// tungstenite-required Vec clone.
285 fn broadcast_binary(&self, msg: &Arc<[u8]>) {
286 let handles: Vec<(u64, ClientSocket)> = {
287 let clients = self.clients.lock().unwrap();
288 clients.iter().map(|(id, h)| (*id, Arc::clone(h))).collect()
289 };
290 let mut dead: Vec<u64> = Vec::new();
291 for (id, handle) in handles {
292 let mut guard = match handle.lock() {
293 Ok(g) => g,
294 Err(poisoned) => poisoned.into_inner(),
295 };
296 if guard.send(Message::Binary(msg.to_vec())).is_err() {
297 dead.push(id);
298 }
299 }
300 if !dead.is_empty() {
301 let mut clients = self.clients.lock().unwrap();
302 for id in &dead {
303 clients.remove(id);
304 }
305 }
306 }
307
308 fn count(&self) -> usize {
309 self.clients.lock().unwrap().len()
310 }
311}
312
/// WebSocket broadcast hub with sharded client storage.
///
/// Aimed at large connection counts (the original notes cite 10k+) with a
/// bounded number of broadcast threads. Uses NUM_SHARDS (16) shards to
/// reduce lock contention.
///
/// Architecture:
/// - Client connections are assigned to shards via round-robin (id % NUM_SHARDS).
/// - Each shard has a dedicated broadcast worker thread that consumes from a channel.
/// - Text broadcast calls are non-blocking for the caller: they push to each shard's
///   channel and return immediately. Binary (CRDT) sends bypass the channels and
///   write to the sockets on the calling thread.
/// - Read-side threads use 64KB stacks (vs 2-8MB default) to keep memory bounded.
/// - Total thread count: NUM_SHARDS broadcast workers + 1 per connected client (with
///   minimal stack), plus the accept thread.
pub struct WsHub {
    /// One entry per shard; a client with id `n` lives in `shards[n % NUM_SHARDS]`.
    shards: Vec<Arc<Shard>>,
    /// Next client id to hand out; ids are sequential and never reused.
    next_id: Mutex<u64>,
    /// Bounded-capacity senders, one per shard broadcast worker. Messages
    /// are pushed with `try_send`; when a shard's queue is full the NEW
    /// message is dropped for that shard rather than blocking the caller.
    ///
    /// Carries `Arc<str>` so a single broadcast event allocates the JSON
    /// once and the 16 shard sends are cheap refcount bumps. Was a 16×
    /// String clone hotspot under high write rates with thousands of
    /// subscribers per shard.
    broadcast_txs: Vec<mpsc::SyncSender<Arc<str>>>,
    /// Capacity each shard queue was created with (`BROADCAST_QUEUE_DEPTH`).
    /// Stored at construction; nothing in the code visible here reads it
    /// back, hence the `dead_code` allowance.
    #[allow(dead_code)]
    queue_depth: usize,
    /// Per-client CRDT subscriptions. Reader threads register `(entity,
    /// row_id)` pairs as the client mounts/unmounts useLoroDoc hooks;
    /// the binary CRDT broadcast path uses `subscribers()` to filter the
    /// fanout. Wrapped in Arc so the notifier (which holds `Arc<WsHub>`)
    /// can read the subscriber set without taking an extra lock layer.
    subscriptions: Arc<CrdtSubscriptions>,
}
352
353impl WsHub {
354 pub fn new() -> Arc<Self> {
355 let mut shards = Vec::with_capacity(NUM_SHARDS);
356 let mut broadcast_txs = Vec::with_capacity(NUM_SHARDS);
357
358 for i in 0..NUM_SHARDS {
359 let shard = Arc::new(Shard::new());
360 // Bounded queue — if a broadcast worker stalls, `try_send` fails
361 // with Full and `broadcast_raw` drops the oldest to make room.
362 let (tx, rx) = mpsc::sync_channel::<Arc<str>>(BROADCAST_QUEUE_DEPTH);
363
364 let shard_clone = Arc::clone(&shard);
365 thread::Builder::new()
366 .name(format!("ws-broadcast-{i}"))
367 .spawn(move || {
368 while let Ok(msg) = rx.recv() {
369 shard_clone.broadcast(&msg);
370 }
371 })
372 .expect("Failed to spawn broadcast worker");
373
374 shards.push(shard);
375 broadcast_txs.push(tx);
376 }
377
378 Arc::new(Self {
379 shards,
380 next_id: Mutex::new(0),
381 broadcast_txs,
382 queue_depth: BROADCAST_QUEUE_DEPTH,
383 subscriptions: CrdtSubscriptions::new(),
384 })
385 }
386
387 /// Access the per-client CRDT subscription registry. The notifier
388 /// looks up subscribers via `subscriptions().subscribers(entity, row)`
389 /// and feeds them to `broadcast_binary_to`.
390 pub fn subscriptions(&self) -> &Arc<CrdtSubscriptions> {
391 &self.subscriptions
392 }
393
394 /// Broadcast a change event to ALL connected clients across all shards.
395 /// Non-blocking: pushes to each shard's channel and returns immediately.
396 ///
397 /// Serializes the event JSON exactly once into an `Arc<str>` and
398 /// shares it across the 16 shard senders. Each shard's worker
399 /// thread receives the same Arc and pays only a refcount bump.
400 pub fn broadcast(&self, event: &ChangeEvent) {
401 let json = match serde_json::to_string(event) {
402 Ok(j) => j,
403 Err(_) => return,
404 };
405 let shared: Arc<str> = Arc::from(json.into_boxed_str());
406 self.broadcast_shared(shared);
407 }
408
409 /// Broadcast a raw string message to all clients (used for presence updates).
410 pub fn broadcast_presence(&self, msg: &str) {
411 let shared: Arc<str> = Arc::from(msg.to_string().into_boxed_str());
412 self.broadcast_shared(shared);
413 }
414
415 /// Broadcast a binary frame to every connected client across all
416 /// shards. Used for CRDT updates (see `pylon_router::encode_crdt_frame`
417 /// for the wire shape). The bytes are wrapped in an `Arc` so each
418 /// shard's per-client fanout shares one allocation; the per-send
419 /// `to_vec()` cost is the tungstenite 0.24 contract.
420 ///
421 /// Synchronous fanout — iterates shards directly rather than going
422 /// through the per-shard mpsc workers. CRDT writes happen at most
423 /// once per logical mutation so the throughput shape is "occasional
424 /// burst" not "every keystroke", and direct fanout avoids growing a
425 /// second per-shard channel (Arc<[u8]> can't share the Arc<str>
426 /// channel without an enum, which costs more than the bypass).
427 pub fn broadcast_binary(&self, bytes: Vec<u8>) {
428 let shared: Arc<[u8]> = Arc::from(bytes.into_boxed_slice());
429 for shard in &self.shards {
430 shard.broadcast_binary(&shared);
431 }
432 }
433
434 /// Send a binary frame to a specific subset of client IDs only.
435 /// Used by the CRDT broadcast path to fan out only to clients
436 /// subscribed to the row that just changed (instead of every
437 /// connected client). Routes each id to its owning shard via
438 /// `id % NUM_SHARDS`.
439 ///
440 /// `client_ids` typically comes from `CrdtSubscriptions::subscribers`.
441 /// An empty list is a no-op — the row had no subscribers, so the
442 /// CRDT write is durable on the server but no client sees the
443 /// binary frame (they'll learn about the change via the JSON
444 /// change-event broadcast which always fires).
445 pub fn broadcast_binary_to(&self, client_ids: &[u64], bytes: Vec<u8>) {
446 if client_ids.is_empty() {
447 return;
448 }
449 let shared: Arc<[u8]> = Arc::from(bytes.into_boxed_slice());
450 // Group ids by shard so each shard's per-client lock is only
451 // grabbed once even if many subscribers landed in the same one.
452 let mut by_shard: Vec<Vec<u64>> = (0..NUM_SHARDS).map(|_| Vec::new()).collect();
453 for id in client_ids {
454 by_shard[(*id as usize) % NUM_SHARDS].push(*id);
455 }
456 for (idx, ids) in by_shard.iter().enumerate() {
457 if ids.is_empty() {
458 continue;
459 }
460 for dead_id in self.shards[idx].send_binary_to(ids, &shared) {
461 // Drop the dead client's subscription entries too —
462 // otherwise they leak until the reader thread's read
463 // timeout fires and runs unsubscribe_all on its own,
464 // and a future broadcast might re-attempt the dead id.
465 self.subscriptions.unsubscribe_all(dead_id);
466 }
467 }
468 }
469
470 /// Send a binary frame to a single client by id. Used by the
471 /// subscribe path: when a client subscribes to a row, the server
472 /// immediately ships the current snapshot so the new subscriber
473 /// has the up-to-date state without waiting for the next write.
474 pub fn send_binary_to_one(&self, client_id: u64, bytes: Vec<u8>) {
475 let shared: Arc<[u8]> = Arc::from(bytes.into_boxed_slice());
476 let shard_idx = (client_id as usize) % NUM_SHARDS;
477 for dead_id in self.shards[shard_idx].send_binary_to(&[client_id], &shared) {
478 self.subscriptions.unsubscribe_all(dead_id);
479 }
480 }
481
482 /// Internal: fan out a single shared message to every shard worker.
483 ///
484 /// Uses `try_send`; on full we log once (per call) and drop the message
485 /// for that shard. Previously the channel was unbounded, so a stuck
486 /// worker thread would grow memory until OOM. The new bounded queue
487 /// means a slow/stuck subscriber at worst loses broadcast events —
488 /// correctness for critical data still comes through the change-log
489 /// cursor on a reconnect.
490 fn broadcast_shared(&self, msg: Arc<str>) {
491 for tx in &self.broadcast_txs {
492 match tx.try_send(Arc::clone(&msg)) {
493 Ok(()) => {}
494 Err(mpsc::TrySendError::Full(_)) => {
495 tracing::warn!("[ws] broadcast queue full — dropping event for one shard");
496 }
497 Err(mpsc::TrySendError::Disconnected(_)) => {
498 // Worker exited (shutdown). Silent.
499 }
500 }
501 }
502 }
503
504 /// Assign a client to a shard via round-robin and register it.
505 /// Returns `(id, socket_handle)` — the caller keeps the handle and uses
506 /// it for reads; the shard also keeps an Arc clone for broadcasts.
507 fn add_client(&self, ws: WebSocket<TcpStream>) -> (u64, ClientSocket) {
508 let mut next_id = self.next_id.lock().unwrap();
509 let id = *next_id;
510 *next_id += 1;
511 let shard_idx = (id as usize) % NUM_SHARDS;
512 let handle = self.shards[shard_idx].add(id, ws);
513 (id, handle)
514 }
515
516 fn remove_client(&self, id: u64) {
517 let shard_idx = (id as usize) % NUM_SHARDS;
518 self.shards[shard_idx].remove(id);
519 }
520
521 /// Total number of connected clients across all shards.
522 pub fn client_count(&self) -> usize {
523 self.shards.iter().map(|s| s.count()).sum()
524 }
525}
526
/// Snapshot fetcher callback: given the caller's auth context plus
/// `(entity, row_id)`, return the encoded binary CRDT frame for the row's
/// current state, or `None`.
///
/// By contract (the implementation lives in the runtime crate and is not
/// visible here), `None` means either the caller can't read the row (read
/// policy denies) or the row has no snapshot (uninitialized CRDT or
/// non-CRDT entity).
///
/// Auth context is passed in (rather than checked at the WS layer)
/// because the policy engine + DataStore handles live in the runtime
/// crate. Without this check an authenticated client could subscribe
/// to any `(entity, row_id)` and receive every binary CRDT frame
/// even for rows their query policy would reject — a silent read-
/// policy bypass.
///
/// Wrapped in an `Arc<dyn Fn>` (`Send + Sync`) so the runtime can build it
/// once and hand a clone of the same closure to every accepted connection's
/// reader thread.
pub type SnapshotFetcher =
    Arc<dyn Fn(&pylon_auth::AuthContext, &str, &str) -> Option<Vec<u8>> + Send + Sync>;
545
546/// Start the WebSocket server on the given port.
547///
548/// The accept loop runs on the calling thread (blocking). Each accepted
549/// connection spawns a lightweight reader thread with a 64KB stack.
550/// Broadcast writes are handled by the shard worker threads, not by
551/// per-client threads.
552///
553/// The session store is required: every connection must present a valid
554/// bearer token (Authorization header or `bearer.<token>` subprotocol —
555/// browsers can't set WS headers directly). Previously the notifier hub
556/// accepted any connection and streamed every ChangeEvent/presence event
557/// to it, which was a silent read-policy bypass.
558///
559/// `snapshot_fetcher` is optional — when present, the reader will ship
560/// the current CRDT snapshot to the subscribing client immediately on
561/// `crdt-subscribe`, so the new tab sees the latest converged state
562/// without waiting for the next write. When absent, subscribe is still
563/// recorded but the catch-up frame is skipped.
564pub fn start_ws_server(
565 hub: Arc<WsHub>,
566 sessions: Arc<SessionStore>,
567 port: u16,
568 snapshot_fetcher: Option<SnapshotFetcher>,
569) {
570 let addr = format!("0.0.0.0:{port}");
571 let listener = match TcpListener::bind(&addr) {
572 Ok(l) => l,
573 Err(e) => {
574 tracing::warn!("[ws] Failed to bind on {addr}: {e}");
575 return;
576 }
577 };
578
579 tracing::warn!(
580 "[ws] WebSocket server listening on ws://localhost:{port} (sharded, {NUM_SHARDS} shards)"
581 );
582
583 let ip_counter = Arc::new(IpConnCounter::default());
584
585 for stream in listener.incoming() {
586 let stream = match stream {
587 Ok(s) => s,
588 Err(_) => continue,
589 };
590
591 // Per-IP connection cap: reject BEFORE the handshake so a cheap
592 // connect storm doesn't force us through tungstenite's HTTP parse
593 // and the session-resolve round trip. The guard is dropped when
594 // the reader thread exits (or fails to start), freeing the slot.
595 let ip = match stream.peer_addr() {
596 Ok(addr) => addr.ip(),
597 Err(_) => continue,
598 };
599 let guard = match ip_counter.acquire(ip) {
600 Some(g) => g,
601 None => {
602 // Ignore: let the client re-try after an existing connection
603 // closes. Previously an IP could open unbounded connections
604 // and each one spawned a thread + held a per-client mutex.
605 continue;
606 }
607 };
608
609 let hub = Arc::clone(&hub);
610 let sessions = Arc::clone(&sessions);
611 let fetcher = snapshot_fetcher.clone();
612 // Spawn a reader thread per client with a small stack.
613 // 64KB stack * 10k connections = ~640MB, vs 2-8MB default * 10k = 20-80GB.
614 let spawn_result = thread::Builder::new()
615 .name("ws-client".into())
616 .stack_size(64 * 1024)
617 .spawn(move || {
618 // Holding `guard` for the life of the connection thread is
619 // what makes the decrement-on-disconnect contract work. Not
620 // `let _ = guard;` — that drops immediately.
621 let _conn_slot = guard;
622 handle_ws_connection(hub, sessions, stream, fetcher);
623 });
624 if spawn_result.is_err() {
625 // Thread creation failed — guard is already dropped here, slot
626 // returned. We deliberately don't call `continue` before the
627 // spawn: we've paid the acquire cost and want to avoid leaking
628 // a slot under transient thread-limit pressure.
629 }
630 }
631}
632
633/// Handle a single WebSocket client connection.
634///
635/// Sets a read timeout to prevent zombie threads on dead connections.
636/// Handles ping/pong for keepalive, presence/topic message relay,
637/// and clean disconnect with presence broadcast.
638fn handle_ws_connection(
639 hub: Arc<WsHub>,
640 sessions: Arc<SessionStore>,
641 stream: TcpStream,
642 snapshot_fetcher: Option<SnapshotFetcher>,
643) {
644 // Short read timeout bounds how long the PER-CLIENT mutex is held
645 // while this thread is blocked in socket.read(). Each client now has
646 // its own mutex (not a shard-wide one), so a quiet client only stalls
647 // the broadcaster when it's broadcasting to THAT specific client —
648 // other clients in the same shard proceed without contention.
649 stream.set_read_timeout(Some(WS_READ_TIMEOUT)).ok();
650 // Also cap write time. A stuck kernel send (slow client, full send
651 // buffer, dropped packets) would otherwise stall the shard's
652 // broadcast worker holding this client's mutex — backpressure
653 // becomes head-of-line blocking for everyone. Capped at 5s; slow
654 // clients get disconnected rather than stalling the hub.
655 stream.set_write_timeout(Some(WS_READ_TIMEOUT)).ok();
656
657 // Extract the bearer token from the handshake, preferring the
658 // Authorization header (native clients) and falling back to the
659 // `bearer.<token>` WebSocket subprotocol (browsers). We only learn
660 // whether the token is valid AFTER accept_hdr completes, since the
661 // header callback must return synchronously with a Response.
662 let token_slot: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
663 let slot_for_cb = Arc::clone(&token_slot);
664 let ws = match accept_hdr(
665 stream,
666 move |req: &Request, mut resp: Response| -> Result<Response, ErrorResponse> {
667 let mut chosen_protocol: Option<String> = None;
668 let mut auth: Option<String> = None;
669 for (name, value) in req.headers() {
670 let lower = name.as_str().to_ascii_lowercase();
671 if lower == "authorization" {
672 if let Ok(v) = value.to_str() {
673 if let Some(tok) = v.strip_prefix("Bearer ") {
674 auth = Some(tok.to_string());
675 }
676 }
677 } else if lower == "sec-websocket-protocol" {
678 if let Ok(v) = value.to_str() {
679 for proto in v.split(',').map(str::trim) {
680 if let Some(encoded) = proto.strip_prefix("bearer.") {
681 if let Some(decoded) = percent_decode_token(encoded) {
682 auth = auth.or(Some(decoded));
683 chosen_protocol = Some(proto.to_string());
684 break;
685 }
686 }
687 }
688 }
689 }
690 }
691 // RFC 6455 §11.3.4 — echo the chosen subprotocol in the response or
692 // browsers will refuse the connection.
693 if let Some(chosen) = chosen_protocol {
694 if let Ok(hv) = tungstenite::http::HeaderValue::from_str(&chosen) {
695 resp.headers_mut().insert("Sec-WebSocket-Protocol", hv);
696 }
697 }
698 *slot_for_cb.lock().unwrap() = auth;
699 Ok(resp)
700 },
701 ) {
702 Ok(ws) => ws,
703 Err(_) => return,
704 };
705
706 // Reject unauthenticated or invalid-token handshakes AFTER accept —
707 // tungstenite's handshake callback can't easily return a 401 without
708 // a custom error response, and we already have the socket open for
709 // a clean close frame.
710 let token = token_slot.lock().unwrap().clone();
711 let auth_ctx = sessions.resolve(token.as_deref());
712 if auth_ctx.user_id.is_none() && !auth_ctx.is_admin {
713 let mut ws = ws;
714 let _ = ws.close(Some(tungstenite::protocol::CloseFrame {
715 code: tungstenite::protocol::frame::coding::CloseCode::Policy,
716 reason: "unauthorized: bearer token required".into(),
717 }));
718 return;
719 }
720
721 let (client_id, socket_handle) = hub.add_client(ws);
722
723 loop {
724 // Lock this client's socket mutex only for the duration of the
725 // read. With a 5s read timeout, broadcasters waiting to send to
726 // THIS client wait at most 5s. Other clients are never blocked
727 // by this lock — they have their own.
728 let msg = {
729 let mut guard = match socket_handle.lock() {
730 Ok(g) => g,
731 Err(poisoned) => poisoned.into_inner(),
732 };
733 guard.read()
734 };
735
736 match msg {
737 Ok(Message::Text(text)) => {
738 // Parse once and dispatch on the type field instead of
739 // matching prefix bytes — that approach silently dropped
740 // valid JSON with whitespace, key reordering, or any
741 // other formatting variation. Non-object / no-`type`
742 // messages are ignored.
743 let parsed: serde_json::Value = match serde_json::from_str(&text) {
744 Ok(v) => v,
745 Err(_) => continue,
746 };
747 let kind = parsed.get("type").and_then(|v| v.as_str()).unwrap_or("");
748 match kind {
749 "presence" | "topic" => {
750 // Stamp the authenticated sender server-side,
751 // overriding any client-provided `from`. Without
752 // this, any client could spoof presence/topic
753 // events as another user — every connected
754 // client would see a forged "alice typed…"
755 // message attributed to alice.
756 let mut stamped = parsed.clone();
757 if let Some(obj) = stamped.as_object_mut() {
758 let from = auth_ctx
759 .user_id
760 .clone()
761 .unwrap_or_else(|| "admin".to_string());
762 obj.insert("from".into(), serde_json::Value::String(from));
763 }
764 hub.broadcast_presence(&stamped.to_string());
765 }
766 "crdt-subscribe" | "crdt-unsubscribe" => handle_crdt_control(
767 &hub,
768 client_id,
769 &auth_ctx,
770 kind,
771 &parsed,
772 snapshot_fetcher.as_ref(),
773 ),
774 _ => {}
775 }
776 }
777 Ok(Message::Ping(data)) => {
778 // Respond with pong to keep the connection alive.
779 if let Ok(mut guard) = socket_handle.lock() {
780 let _ = guard.send(Message::Pong(data));
781 }
782 }
783 Ok(Message::Close(_)) => {
784 // Drop every CRDT subscription this client held BEFORE
785 // remove_client so the broadcast path can never look up
786 // a stale client_id between the two ops.
787 hub.subscriptions.unsubscribe_all(client_id);
788 hub.remove_client(client_id);
789 let disconnect = serde_json::json!({
790 "type": "presence",
791 "event": "disconnect",
792 "clientId": client_id,
793 });
794 hub.broadcast_presence(&disconnect.to_string());
795 break;
796 }
797 Err(tungstenite::Error::Io(io_err))
798 if io_err.kind() == std::io::ErrorKind::WouldBlock
799 || io_err.kind() == std::io::ErrorKind::TimedOut =>
800 {
801 // Read timed out — this is EXPECTED with the short
802 // timeout. In theory the mutex is released between
803 // iterations, but `std::sync::Mutex` is not fair: a tight
804 // loop of lock→read→unlock→lock starves the broadcaster
805 // that's been waiting on the same mutex. Explicitly sleep
806 // for a tick so the broadcaster gets scheduled. 1ms is
807 // long enough to hand off, short enough that client→server
808 // latency stays sub-5ms.
809 std::thread::sleep(std::time::Duration::from_millis(1));
810 continue;
811 }
812 Err(_) => {
813 hub.subscriptions.unsubscribe_all(client_id);
814 hub.remove_client(client_id);
815 let disconnect = serde_json::json!({
816 "type": "presence",
817 "event": "disconnect",
818 "clientId": client_id,
819 });
820 hub.broadcast_presence(&disconnect.to_string());
821 break;
822 }
823 _ => {}
824 }
825 }
826}
827
828/// Apply a parsed `crdt-subscribe` / `crdt-unsubscribe` control
829/// message. Both messages have the shape:
830///
831/// { "type": "crdt-subscribe", "entity": "<E>", "rowId": "<id>" }
832/// { "type": "crdt-unsubscribe", "entity": "<E>", "rowId": "<id>" }
833///
834/// On subscribe the snapshot fetcher checks read policy for the
835/// caller's auth context — if the caller can't read the row we
836/// register no subscription and ship nothing back, so a malicious
837/// client can't peek at a row their query policy would block by
838/// just subscribing to its CRDT stream.
839///
840/// Malformed messages are silently dropped — there's no client-visible
841/// ACK protocol, so a typo in the payload would just look like a
842/// row that never receives updates. Logging would invite a noise
843/// channel for misbehaving clients.
844fn handle_crdt_control(
845 hub: &Arc<WsHub>,
846 client_id: u64,
847 auth_ctx: &pylon_auth::AuthContext,
848 kind: &str,
849 parsed: &serde_json::Value,
850 snapshot_fetcher: Option<&SnapshotFetcher>,
851) {
852 let entity = match parsed.get("entity").and_then(|v| v.as_str()) {
853 Some(e) if !e.is_empty() => e,
854 _ => return,
855 };
856 let row_id = match parsed
857 .get("rowId")
858 .or_else(|| parsed.get("row_id"))
859 .and_then(|v| v.as_str())
860 {
861 Some(r) if !r.is_empty() => r,
862 _ => return,
863 };
864
865 match kind {
866 "crdt-subscribe" => {
867 // Authz check happens INSIDE the fetcher (it has access to
868 // the policy engine + DataStore). When a fetcher is wired
869 // and returns None, the caller is either denied or the row
870 // doesn't exist — in both cases we refuse to register the
871 // subscription so a denied caller can't silently hold an
872 // open slot waiting for future writes.
873 //
874 // When no fetcher is wired (test harnesses, future
875 // workers backend without DataStore access) we trust the
876 // caller and register without the auth gate. Production
877 // server.rs always wires one, so this loophole is
878 // unreachable in deployed configurations.
879 let snapshot = snapshot_fetcher.and_then(|f| f(auth_ctx, entity, row_id));
880 let allow_subscribe = snapshot_fetcher.is_none() || snapshot.is_some();
881 if allow_subscribe {
882 hub.subscriptions.subscribe(client_id, entity, row_id);
883 if let Some(bytes) = snapshot {
884 hub.send_binary_to_one(client_id, bytes);
885 }
886 }
887 }
888 "crdt-unsubscribe" => {
889 hub.subscriptions.unsubscribe(client_id, entity, row_id);
890 }
891 _ => {}
892 }
893}
894
/// Decode the percent-encoded token carried by the `bearer.<token>`
/// subprotocol.
///
/// Decoding rules:
/// - `%XY` (two hex digits, either case) becomes the byte `0xXY`;
/// - `+` becomes a space;
/// - every other byte is copied through unchanged.
///
/// Returns `None` when a `%` is not followed by two hex digits or when
/// the decoded bytes are not valid UTF-8, so malformed input is
/// rejected here instead of being forwarded as garbage.
fn percent_decode_token(s: &str) -> Option<String> {
    /// Numeric value of a single ASCII hex digit, if it is one.
    fn hex_val(digit: u8) -> Option<u8> {
        match digit {
            b'0'..=b'9' => Some(digit - b'0'),
            b'a'..=b'f' => Some(digit - b'a' + 10),
            b'A'..=b'F' => Some(digit - b'A' + 10),
            _ => None,
        }
    }

    let mut decoded = Vec::with_capacity(s.len());
    let mut rest = s.bytes();
    // `while let` rather than `for`: an escape consumes two extra bytes
    // from the same iterator inside the loop body.
    while let Some(byte) = rest.next() {
        let value = match byte {
            b'%' => {
                // A truncated escape (`%` or `%X` at the end) bails out
                // through the `?` on `next()`.
                let hi = hex_val(rest.next()?)?;
                let lo = hex_val(rest.next()?)?;
                (hi << 4) | lo
            }
            b'+' => b' ',
            other => other,
        };
        decoded.push(value);
    }
    String::from_utf8(decoded).ok()
}
926
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built shard reports no clients.
    #[test]
    fn shard_count_starts_at_zero() {
        let fresh = Shard::new();
        assert_eq!(fresh.count(), 0);
    }

    /// A freshly built hub reports no clients.
    #[test]
    fn hub_starts_with_zero_clients() {
        let fresh = WsHub::new();
        assert_eq!(fresh.client_count(), 0);
    }

    /// Both broadcast flavours must tolerate a hub with nobody connected.
    #[test]
    fn broadcast_to_empty_hub_doesnt_panic() {
        let empty_hub = WsHub::new();
        let insert = ChangeEvent {
            seq: 1,
            entity: "Test".into(),
            row_id: "1".into(),
            kind: pylon_sync::ChangeKind::Insert,
            data: None,
            timestamp: String::new(),
        };
        empty_hub.broadcast(&insert);
        empty_hub.broadcast_presence("test");
    }

    /// Guards the shard-count constant against a non-power-of-two value.
    #[test]
    fn num_shards_is_power_of_two() {
        // Power-of-two shard count ensures even distribution with modulo.
        assert!(
            NUM_SHARDS.is_power_of_two(),
            "NUM_SHARDS ({NUM_SHARDS}) must be a power of two for even distribution"
        );
    }

    /// Subscribing the same client to the same row twice counts once.
    #[test]
    fn crdt_subscriptions_subscribe_dedups() {
        let registry = CrdtSubscriptions::default();
        for _ in 0..2 {
            registry.subscribe(1, "Channel", "abc");
        }
        assert_eq!(registry.subscribers("Channel", "abc"), vec![1]);
        assert_eq!(registry.total_subscriptions(), 1);
    }

    /// Every client subscribed to a row shows up in its subscriber list.
    #[test]
    fn crdt_subscriptions_returns_all_subscribers() {
        let registry = CrdtSubscriptions::default();
        for client in 1..=3u64 {
            registry.subscribe(client, "Channel", "abc");
        }
        // Order of the returned ids is not asserted, so normalise it.
        let mut seen = registry.subscribers("Channel", "abc");
        seen.sort_unstable();
        assert_eq!(seen, vec![1, 2, 3]);
    }

    /// Removing the last subscriber must also remove the row entry.
    #[test]
    fn crdt_subscriptions_unsubscribe_cleans_empty_rows() {
        let registry = CrdtSubscriptions::default();
        registry.subscribe(1, "Channel", "abc");
        registry.unsubscribe(1, "Channel", "abc");
        assert!(registry.subscribers("Channel", "abc").is_empty());
        // The total must reflect that the emptied row entry is gone,
        // not linger as a zero-sized set.
        assert_eq!(registry.total_subscriptions(), 0);
    }

    /// `unsubscribe_all` wipes one client's rows and nobody else's.
    #[test]
    fn crdt_subscriptions_unsubscribe_all_drops_every_row() {
        let registry = CrdtSubscriptions::default();
        for (entity, row) in [("Channel", "a"), ("Channel", "b"), ("Message", "m1")] {
            registry.subscribe(1, entity, row);
        }
        // A second client on a shared row; it must survive the wipe.
        registry.subscribe(2, "Channel", "a");

        registry.unsubscribe_all(1);

        assert!(registry.subscribers("Channel", "b").is_empty());
        assert!(registry.subscribers("Message", "m1").is_empty());
        // Client 2 is still there.
        assert_eq!(registry.subscribers("Channel", "a"), vec![2]);
    }

    /// Unsubscribing a client that never subscribed changes nothing.
    #[test]
    fn crdt_subscriptions_unsubscribe_unknown_client_is_noop() {
        let registry = CrdtSubscriptions::default();
        registry.unsubscribe(99, "Channel", "abc");
        registry.unsubscribe_all(99);
        assert_eq!(registry.total_subscriptions(), 0);
    }

    /// Stress subscribe + unsubscribe from many threads at once to check
    /// that the single-mutex design keeps `by_row` and `by_client` in
    /// step. (The earlier two-mutex version could diverge under
    /// interleaving.)
    #[test]
    fn crdt_subscriptions_concurrent_subscribe_and_unsubscribe() {
        let registry = Arc::new(CrdtSubscriptions::default());
        // Collect eagerly so every worker is spawned before any join.
        let workers: Vec<_> = (0..16u64)
            .map(|client_id| {
                let registry = Arc::clone(&registry);
                std::thread::spawn(move || {
                    for i in 0..200 {
                        let row = format!("row-{i}");
                        registry.subscribe(client_id, "Channel", &row);
                        registry.unsubscribe(client_id, "Channel", &row);
                    }
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
        // Each subscribe was matched by an unsubscribe, so nothing may
        // be left behind.
        assert_eq!(registry.total_subscriptions(), 0);
    }

    /// Concurrent subscribes followed by per-client wipes leave no orphans.
    #[test]
    fn crdt_subscriptions_unsubscribe_all_after_concurrent_subscribes() {
        let registry = Arc::new(CrdtSubscriptions::default());
        // Collect eagerly so every worker is spawned before any join.
        let workers: Vec<_> = (0..8u64)
            .map(|client_id| {
                let registry = Arc::clone(&registry);
                std::thread::spawn(move || {
                    for i in 0..100 {
                        let row = format!("row-{i}");
                        registry.subscribe(client_id, "Channel", &row);
                    }
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
        // Wipe each client in turn; no row may remain afterwards.
        (0..8u64).for_each(|client_id| registry.unsubscribe_all(client_id));
        assert_eq!(registry.total_subscriptions(), 0);
    }

    /// Sequential ids taken modulo `NUM_SHARDS` land evenly on every shard.
    #[test]
    fn shard_assignment_distributes_evenly() {
        let total_ids = NUM_SHARDS as u64 * 100;
        let mut per_shard = vec![0usize; NUM_SHARDS];
        (0..total_ids).for_each(|id| per_shard[(id as usize) % NUM_SHARDS] += 1);
        // Every shard should get exactly 100 clients.
        for (i, count) in per_shard.iter().enumerate() {
            assert_eq!(*count, 100, "Shard {i} got {count} clients, expected 100");
        }
    }
}