Skip to main content

pylon_runtime/
ws.rs

1use std::collections::{HashMap, HashSet};
2use std::net::{TcpListener, TcpStream};
3use std::sync::mpsc;
4use std::sync::{Arc, Mutex};
5use std::thread;
6use std::time::Duration;
7
8use pylon_auth::SessionStore;
9use pylon_sync::ChangeEvent;
10use tungstenite::handshake::server::{ErrorResponse, Request, Response};
11use tungstenite::{accept_hdr_with_config, protocol::WebSocketConfig, Message, WebSocket};
12
13use crate::ip_limit::IpConnCounter;
14
15// ---------------------------------------------------------------------------
16// CRDT subscription manager
17//
18// Per-client subscriptions to (entity, row_id) pairs. Lets the binary CRDT
19// broadcast filter to only the clients that asked, instead of fanning out
20// every CRDT write to every connected WS client.
21//
22// Two reverse maps so both hot paths are O(subscribers per row) and
23// O(rows per client): the broadcast looks up subscribers by row, the
24// disconnect cleanup walks rows by client.
25//
26// Subscriptions are explicit and ephemeral — a client subscribes when
27// useLoroDoc(entity, id) mounts, unsubscribes on unmount or disconnect.
28// Server doesn't persist subscriptions across reconnects; the client
29// re-sends them.
30// ---------------------------------------------------------------------------
31
/// Both directions of the CRDT subscription index. Always mutated together
/// under the single `CrdtSubscriptions::state` mutex.
#[derive(Default)]
struct SubsState {
    /// Row key `(entity, row_id)` → ids of the clients watching that row.
    by_row: HashMap<(String, String), HashSet<u64>>,
    /// Client id → row keys that client watches. This reverse index keeps
    /// disconnect cleanup proportional to the client's own rows rather than
    /// to every row present in `by_row`.
    by_client: HashMap<u64, HashSet<(String, String)>>,
}

impl SubsState {
    /// Remove `client_id` from the watchers of `key`, pruning the row entry
    /// once nobody watches it so `by_row` never holds empty sets.
    fn detach_from_row(&mut self, key: &(String, String), client_id: u64) {
        let now_empty = match self.by_row.get_mut(key) {
            Some(watchers) => {
                watchers.remove(&client_id);
                watchers.is_empty()
            }
            None => false,
        };
        if now_empty {
            self.by_row.remove(key);
        }
    }

    /// Remove `key` from the rows watched by `client_id`, pruning the client
    /// entry once it watches nothing.
    fn detach_from_client(&mut self, client_id: u64, key: &(String, String)) {
        let now_empty = match self.by_client.get_mut(&client_id) {
            Some(rows) => {
                rows.remove(key);
                rows.is_empty()
            }
            None => false,
        };
        if now_empty {
            self.by_client.remove(&client_id);
        }
    }
}

/// Registry of which clients watch which CRDT rows.
///
/// One mutex guards both reverse maps. Any pair of operations then observes
/// them in agreement: subscribe racing unsubscribe, or broadcast racing
/// disconnect cleanup. With a mutex per map, a `subscribe` could land in
/// `by_row` while a concurrent `unsubscribe_all` was reading a half-updated
/// `by_client`, leaving the two maps out of sync.
#[derive(Default)]
pub struct CrdtSubscriptions {
    state: Mutex<SubsState>,
}

impl CrdtSubscriptions {
    /// Create an empty registry, already wrapped for sharing across threads.
    pub fn new() -> Arc<Self> {
        Arc::default()
    }

    /// Record that `client_id` watches `(entity, row_id)`. Repeating the same
    /// call changes nothing, because both sides are sets.
    pub fn subscribe(&self, client_id: u64, entity: &str, row_id: &str) {
        let key = (entity.to_owned(), row_id.to_owned());
        let mut guard = self.state.lock().unwrap();
        let SubsState { by_row, by_client } = &mut *guard;
        by_row.entry(key.clone()).or_default().insert(client_id);
        by_client.entry(client_id).or_default().insert(key);
    }

    /// Stop `client_id` watching `(entity, row_id)`.
    ///
    /// Entries that become empty are removed from both maps. Long-lived
    /// connections that subscribe to and unsubscribe from many rows therefore
    /// leave no empty entries behind.
    pub fn unsubscribe(&self, client_id: u64, entity: &str, row_id: &str) {
        let key = (entity.to_owned(), row_id.to_owned());
        let mut guard = self.state.lock().unwrap();
        guard.detach_from_row(&key, client_id);
        guard.detach_from_client(client_id, &key);
    }

    /// Remove every subscription held by `client_id`.
    ///
    /// Used on WS disconnect, and when a send to that client fails. The
    /// whole removal happens under one lock hold, so a concurrent
    /// `subscribers` snapshot sees the client either entirely present or
    /// entirely gone.
    pub fn unsubscribe_all(&self, client_id: u64) {
        let mut guard = self.state.lock().unwrap();
        if let Some(rows) = guard.by_client.remove(&client_id) {
            for key in &rows {
                guard.detach_from_row(key, client_id);
            }
        }
    }

    /// Copy out the ids currently watching `(entity, row_id)`.
    ///
    /// The result is an owned `Vec` rather than a borrow into the map. The
    /// broadcast path can therefore release the mutex before it loops over
    /// per-client sends.
    pub fn subscribers(&self, entity: &str, row_id: &str) -> Vec<u64> {
        let key = (entity.to_owned(), row_id.to_owned());
        let guard = self.state.lock().unwrap();
        match guard.by_row.get(&key) {
            Some(watchers) => watchers.iter().copied().collect(),
            None => Vec::new(),
        }
    }

    /// Diagnostic: number of (client, row) pairs currently registered.
    pub fn total_subscriptions(&self) -> usize {
        let guard = self.state.lock().unwrap();
        guard
            .by_row
            .values()
            .fold(0, |total, watchers| total + watchers.len())
    }
}
143
/// Number of shards for distributing WebSocket clients.
///
/// Client ids are handed out sequentially and routed with `id % NUM_SHARDS`.
/// Any non-zero value therefore spreads clients evenly. The count is kept a
/// power of two so the modulo reduces to a bit mask. The assertion below
/// enforces this at compile time, which also rules out a zero shard count.
const NUM_SHARDS: usize = 16;

const _: () = assert!(NUM_SHARDS.is_power_of_two());
147
/// Maximum number of outbound messages queued per shard.
///
/// Once a shard's broadcast worker falls this many messages behind, later
/// broadcasts to that shard are dropped and a warning is logged. The NEW
/// message is discarded; nothing already queued is evicted. Slow subscribers
/// can therefore miss messages. The alternative, an unbounded queue, ran out
/// of memory when a single stuck client blocked its shard worker.
///
/// Callers that need exact delivery should layer their own retry on top.
/// The change-log cursor protocol already does this for sync.
const BROADCAST_QUEUE_DEPTH: usize = 1024;
157
/// Upper bound on how long a single WebSocket read may block.
///
/// A reader holds its client's socket mutex while blocked in a read. A short
/// timeout makes that mutex come free regularly, so a waiting broadcaster
/// gets its turn even if the client never sends anything. The earlier 120s
/// value let one quiet client wedge the shard's writer for up to two minutes.
///
/// NOTE(review): `handle_ws_connection` also passes this constant to
/// `set_write_timeout`, but the comment there describes a 5s write cap.
/// Confirm which write timeout is intended.
const WS_READ_TIMEOUT: Duration = Duration::from_micros(200_000);
163
164/// One entry per connected client. The socket lives behind its OWN
165/// `Mutex`, not a shard-wide one, so the reader thread's blocking
166/// `socket.read()` doesn't hold a lock that covers every client in the
167/// same shard. The broadcaster iterates the client map (outer lock is
168/// brief — O(count of clients in shard)), then grabs each client's
169/// individual mutex to do the `socket.send`. Contention is now per-
170/// client instead of per-shard.
171type ClientSocket = Arc<Mutex<WebSocket<TcpStream>>>;
172
173/// A single shard holding a subset of WebSocket clients.
174///
175/// The outer `Mutex<HashMap>` is held only for insert/remove and while
176/// enumerating client handles — never across I/O.
177struct Shard {
178    clients: Mutex<HashMap<u64, ClientSocket>>,
179}
180
181impl Shard {
182    fn new() -> Self {
183        Self {
184            clients: Mutex::new(HashMap::new()),
185        }
186    }
187
188    fn add(&self, id: u64, ws: WebSocket<TcpStream>) -> ClientSocket {
189        let handle = Arc::new(Mutex::new(ws));
190        self.clients.lock().unwrap().insert(id, Arc::clone(&handle));
191        handle
192    }
193
194    fn remove(&self, id: u64) {
195        self.clients.lock().unwrap().remove(&id);
196    }
197
198    /// Send a message to all clients in this shard.
199    ///
200    /// Snapshot the client handles under the shard lock, drop the shard
201    /// lock, then contend only with per-client mutexes to do the writes.
202    /// This is what lets a reader thread hold its client's mutex for a
203    /// socket.read() without stalling broadcasts for the whole shard.
204    ///
205    /// `msg` is `Arc<str>` rather than `&str` so the caller can serialize
206    /// the JSON exactly once and share the same allocation across all
207    /// 16 shards. Per-client `Message::Text` still allocates an owned
208    /// String (tungstenite 0.24 requires it), but the broadcast no
209    /// longer pays N copies of the JSON across shard channels.
210    fn broadcast(&self, msg: &Arc<str>) {
211        let handles: Vec<(u64, ClientSocket)> = {
212            let clients = self.clients.lock().unwrap();
213            clients.iter().map(|(id, h)| (*id, Arc::clone(h))).collect()
214        };
215        let mut dead: Vec<u64> = Vec::new();
216        for (id, handle) in handles {
217            // `try_lock` would skip clients whose reader is currently
218            // blocked in read(); we prefer `lock()` here so the occasional
219            // broadcaster wait (bounded by the 200ms read timeout) doesn't
220            // drop the message for that client.
221            let mut guard = match handle.lock() {
222                Ok(g) => g,
223                Err(poisoned) => poisoned.into_inner(),
224            };
225            // Owned String per send is the tungstenite 0.24 contract.
226            // The clone here copies the string contents; sharing the
227            // raw bytes via Utf8Bytes would be the next-level
228            // optimization but requires a tungstenite version bump.
229            if guard.send(Message::Text((**msg).to_string())).is_err() {
230                dead.push(id);
231            }
232        }
233        if !dead.is_empty() {
234            let mut clients = self.clients.lock().unwrap();
235            for id in &dead {
236                clients.remove(id);
237            }
238        }
239    }
240
241    /// Send a binary frame to a SPECIFIC subset of this shard's clients.
242    /// Used by the per-client subscription path — `WsHub::broadcast_binary_to`
243    /// computes which ids each shard owns and calls this with just those.
244    ///
245    /// Same per-client lock pattern as `broadcast` / `broadcast_binary`,
246    /// just filtered up front instead of iterating the whole shard.
247    ///
248    /// Returns the list of client ids whose send failed so the caller
249    /// can also clear those ids from the CRDT subscription registry —
250    /// without that step a dead client's subscription entries linger
251    /// until the reader thread notices the EOF and runs unsubscribe_all,
252    /// which can take up to one read-timeout (200ms) longer than the
253    /// send-side death detection.
254    fn send_binary_to(&self, ids: &[u64], msg: &Arc<[u8]>) -> Vec<u64> {
255        let handles: Vec<(u64, ClientSocket)> = {
256            let clients = self.clients.lock().unwrap();
257            ids.iter()
258                .filter_map(|id| clients.get(id).map(|h| (*id, Arc::clone(h))))
259                .collect()
260        };
261        let mut dead: Vec<u64> = Vec::new();
262        for (id, handle) in handles {
263            let mut guard = match handle.lock() {
264                Ok(g) => g,
265                Err(poisoned) => poisoned.into_inner(),
266            };
267            if guard.send(Message::Binary(msg.to_vec())).is_err() {
268                dead.push(id);
269            }
270        }
271        if !dead.is_empty() {
272            let mut clients = self.clients.lock().unwrap();
273            for id in &dead {
274                clients.remove(id);
275            }
276        }
277        dead
278    }
279
280    /// Binary fanout for CRDT updates. Same per-client lock pattern as
281    /// `broadcast` above; the only difference is `Message::Binary` and
282    /// the payload is `Arc<[u8]>` so a single Loro snapshot allocates
283    /// once and the per-client send pays a refcount bump + the
284    /// tungstenite-required Vec clone.
285    fn broadcast_binary(&self, msg: &Arc<[u8]>) {
286        let handles: Vec<(u64, ClientSocket)> = {
287            let clients = self.clients.lock().unwrap();
288            clients.iter().map(|(id, h)| (*id, Arc::clone(h))).collect()
289        };
290        let mut dead: Vec<u64> = Vec::new();
291        for (id, handle) in handles {
292            let mut guard = match handle.lock() {
293                Ok(g) => g,
294                Err(poisoned) => poisoned.into_inner(),
295            };
296            if guard.send(Message::Binary(msg.to_vec())).is_err() {
297                dead.push(id);
298            }
299        }
300        if !dead.is_empty() {
301            let mut clients = self.clients.lock().unwrap();
302            for id in &dead {
303                clients.remove(id);
304            }
305        }
306    }
307
308    fn count(&self) -> usize {
309        self.clients.lock().unwrap().len()
310    }
311}
312
313/// High-performance WebSocket broadcast hub with sharded client storage.
314///
315/// Supports 10k+ concurrent connections with bounded thread count.
316/// Uses NUM_SHARDS (16) shards to reduce lock contention.
317///
318/// Architecture:
319/// - Client connections are assigned to shards via round-robin (id % NUM_SHARDS).
320/// - Each shard has a dedicated broadcast worker thread that consumes from a channel.
321/// - Broadcast calls are non-blocking for the caller: they push to each shard's channel
322///   and return immediately.
323/// - Read-side threads use 64KB stacks (vs 2-8MB default) to keep memory bounded.
324/// - Total thread count: NUM_SHARDS broadcast workers + 1 per connected client (with
325///   minimal stack), plus the accept thread.
326pub struct WsHub {
327    shards: Vec<Arc<Shard>>,
328    next_id: Mutex<u64>,
329    /// Bounded-capacity senders for each shard's broadcast worker. When
330    /// a send would block because the queue is full, `broadcast_raw` drains
331    /// the oldest queued messages so new ones aren't lost to a stuck worker.
332    ///
333    /// Carries `Arc<str>` so a single broadcast event allocates the JSON
334    /// once and the 16 shard sends are cheap refcount bumps. Was a 16×
335    /// String clone hotspot under high write rates with thousands of
336    /// subscribers per shard.
337    broadcast_txs: Vec<mpsc::SyncSender<Arc<str>>>,
338    /// Matching receivers are held by each worker thread and also exposed
339    /// here so the "drop oldest" fallback can drain them on full. Keeping
340    /// the receiver handle alongside the sender is only safe because mpsc
341    /// lets multiple clones share a queue — here we only consume via the
342    /// worker, and the sender-side uses `try_send` + drain retry.
343    #[allow(dead_code)]
344    queue_depth: usize,
345    /// Per-client CRDT subscriptions. Reader threads register `(entity,
346    /// row_id)` pairs as the client mounts/unmounts useLoroDoc hooks;
347    /// the binary CRDT broadcast path uses `subscribers()` to filter the
348    /// fanout. Wrapped in Arc so the notifier (which holds `Arc<WsHub>`)
349    /// can read the subscriber set without taking an extra lock layer.
350    subscriptions: Arc<CrdtSubscriptions>,
351}
352
353impl WsHub {
354    pub fn new() -> Arc<Self> {
355        let mut shards = Vec::with_capacity(NUM_SHARDS);
356        let mut broadcast_txs = Vec::with_capacity(NUM_SHARDS);
357
358        for i in 0..NUM_SHARDS {
359            let shard = Arc::new(Shard::new());
360            // Bounded queue — if a broadcast worker stalls, `try_send` fails
361            // with Full and `broadcast_raw` drops the oldest to make room.
362            let (tx, rx) = mpsc::sync_channel::<Arc<str>>(BROADCAST_QUEUE_DEPTH);
363
364            let shard_clone = Arc::clone(&shard);
365            thread::Builder::new()
366                .name(format!("ws-broadcast-{i}"))
367                .spawn(move || {
368                    while let Ok(msg) = rx.recv() {
369                        shard_clone.broadcast(&msg);
370                    }
371                })
372                .expect("Failed to spawn broadcast worker");
373
374            shards.push(shard);
375            broadcast_txs.push(tx);
376        }
377
378        Arc::new(Self {
379            shards,
380            next_id: Mutex::new(0),
381            broadcast_txs,
382            queue_depth: BROADCAST_QUEUE_DEPTH,
383            subscriptions: CrdtSubscriptions::new(),
384        })
385    }
386
387    /// Access the per-client CRDT subscription registry. The notifier
388    /// looks up subscribers via `subscriptions().subscribers(entity, row)`
389    /// and feeds them to `broadcast_binary_to`.
390    pub fn subscriptions(&self) -> &Arc<CrdtSubscriptions> {
391        &self.subscriptions
392    }
393
394    /// Broadcast a change event to ALL connected clients across all shards.
395    /// Non-blocking: pushes to each shard's channel and returns immediately.
396    ///
397    /// Serializes the event JSON exactly once into an `Arc<str>` and
398    /// shares it across the 16 shard senders. Each shard's worker
399    /// thread receives the same Arc and pays only a refcount bump.
400    pub fn broadcast(&self, event: &ChangeEvent) {
401        let json = match serde_json::to_string(event) {
402            Ok(j) => j,
403            Err(_) => return,
404        };
405        let shared: Arc<str> = Arc::from(json.into_boxed_str());
406        self.broadcast_shared(shared);
407    }
408
409    /// Broadcast a raw string message to all clients (used for presence updates).
410    pub fn broadcast_presence(&self, msg: &str) {
411        let shared: Arc<str> = Arc::from(msg.to_string().into_boxed_str());
412        self.broadcast_shared(shared);
413    }
414
415    /// Broadcast a binary frame to every connected client across all
416    /// shards. Used for CRDT updates (see `pylon_router::encode_crdt_frame`
417    /// for the wire shape). The bytes are wrapped in an `Arc` so each
418    /// shard's per-client fanout shares one allocation; the per-send
419    /// `to_vec()` cost is the tungstenite 0.24 contract.
420    ///
421    /// Synchronous fanout — iterates shards directly rather than going
422    /// through the per-shard mpsc workers. CRDT writes happen at most
423    /// once per logical mutation so the throughput shape is "occasional
424    /// burst" not "every keystroke", and direct fanout avoids growing a
425    /// second per-shard channel (Arc<[u8]> can't share the Arc<str>
426    /// channel without an enum, which costs more than the bypass).
427    pub fn broadcast_binary(&self, bytes: Vec<u8>) {
428        let shared: Arc<[u8]> = Arc::from(bytes.into_boxed_slice());
429        for shard in &self.shards {
430            shard.broadcast_binary(&shared);
431        }
432    }
433
434    /// Send a binary frame to a specific subset of client IDs only.
435    /// Used by the CRDT broadcast path to fan out only to clients
436    /// subscribed to the row that just changed (instead of every
437    /// connected client). Routes each id to its owning shard via
438    /// `id % NUM_SHARDS`.
439    ///
440    /// `client_ids` typically comes from `CrdtSubscriptions::subscribers`.
441    /// An empty list is a no-op — the row had no subscribers, so the
442    /// CRDT write is durable on the server but no client sees the
443    /// binary frame (they'll learn about the change via the JSON
444    /// change-event broadcast which always fires).
445    pub fn broadcast_binary_to(&self, client_ids: &[u64], bytes: Vec<u8>) {
446        if client_ids.is_empty() {
447            return;
448        }
449        let shared: Arc<[u8]> = Arc::from(bytes.into_boxed_slice());
450        // Group ids by shard so each shard's per-client lock is only
451        // grabbed once even if many subscribers landed in the same one.
452        let mut by_shard: Vec<Vec<u64>> = (0..NUM_SHARDS).map(|_| Vec::new()).collect();
453        for id in client_ids {
454            by_shard[(*id as usize) % NUM_SHARDS].push(*id);
455        }
456        for (idx, ids) in by_shard.iter().enumerate() {
457            if ids.is_empty() {
458                continue;
459            }
460            for dead_id in self.shards[idx].send_binary_to(ids, &shared) {
461                // Drop the dead client's subscription entries too —
462                // otherwise they leak until the reader thread's read
463                // timeout fires and runs unsubscribe_all on its own,
464                // and a future broadcast might re-attempt the dead id.
465                self.subscriptions.unsubscribe_all(dead_id);
466            }
467        }
468    }
469
470    /// Send a binary frame to a single client by id. Used by the
471    /// subscribe path: when a client subscribes to a row, the server
472    /// immediately ships the current snapshot so the new subscriber
473    /// has the up-to-date state without waiting for the next write.
474    pub fn send_binary_to_one(&self, client_id: u64, bytes: Vec<u8>) {
475        let shared: Arc<[u8]> = Arc::from(bytes.into_boxed_slice());
476        let shard_idx = (client_id as usize) % NUM_SHARDS;
477        for dead_id in self.shards[shard_idx].send_binary_to(&[client_id], &shared) {
478            self.subscriptions.unsubscribe_all(dead_id);
479        }
480    }
481
482    /// Internal: fan out a single shared message to every shard worker.
483    ///
484    /// Uses `try_send`; on full we log once (per call) and drop the message
485    /// for that shard. Previously the channel was unbounded, so a stuck
486    /// worker thread would grow memory until OOM. The new bounded queue
487    /// means a slow/stuck subscriber at worst loses broadcast events —
488    /// correctness for critical data still comes through the change-log
489    /// cursor on a reconnect.
490    fn broadcast_shared(&self, msg: Arc<str>) {
491        for tx in &self.broadcast_txs {
492            match tx.try_send(Arc::clone(&msg)) {
493                Ok(()) => {}
494                Err(mpsc::TrySendError::Full(_)) => {
495                    tracing::warn!("[ws] broadcast queue full — dropping event for one shard");
496                }
497                Err(mpsc::TrySendError::Disconnected(_)) => {
498                    // Worker exited (shutdown). Silent.
499                }
500            }
501        }
502    }
503
504    /// Assign a client to a shard via round-robin and register it.
505    /// Returns `(id, socket_handle)` — the caller keeps the handle and uses
506    /// it for reads; the shard also keeps an Arc clone for broadcasts.
507    fn add_client(&self, ws: WebSocket<TcpStream>) -> (u64, ClientSocket) {
508        let mut next_id = self.next_id.lock().unwrap();
509        let id = *next_id;
510        *next_id += 1;
511        let shard_idx = (id as usize) % NUM_SHARDS;
512        let handle = self.shards[shard_idx].add(id, ws);
513        (id, handle)
514    }
515
516    fn remove_client(&self, id: u64) {
517        let shard_idx = (id as usize) % NUM_SHARDS;
518        self.shards[shard_idx].remove(id);
519    }
520
521    /// Total number of connected clients across all shards.
522    pub fn client_count(&self) -> usize {
523        self.shards.iter().map(|s| s.count()).sum()
524    }
525}
526
527/// Snapshot fetcher: given the caller's auth context + `(entity,
528/// row_id)`, return the encoded binary CRDT frame for the row's
529/// current state, or `None` if either the caller can't read the row
530/// (read policy denies) or the row has no snapshot (uninitialized
531/// CRDT or non-CRDT entity).
532///
533/// Auth context is passed in (rather than checked at the WS layer)
534/// because the policy engine + DataStore handles live in the runtime
535/// crate. Without this check an authenticated client could subscribe
536/// to any `(entity, row_id)` and receive every binary CRDT frame
537/// even for rows their query policy would reject — a silent read-
538/// policy bypass.
539///
540/// Wrapped in an Arc<dyn Fn> so the runtime can build it once, capturing
541/// the LoroStore + PolicyEngine handles, and hand the same closure to
542/// every accepted connection.
543pub type SnapshotFetcher =
544    Arc<dyn Fn(&pylon_auth::AuthContext, &str, &str) -> Option<Vec<u8>> + Send + Sync>;
545
546/// Start the WebSocket server on the given port.
547///
548/// The accept loop runs on the calling thread (blocking). Each accepted
549/// connection spawns a lightweight reader thread with a 64KB stack.
550/// Broadcast writes are handled by the shard worker threads, not by
551/// per-client threads.
552///
553/// The session store is required: every connection must present a valid
554/// bearer token (Authorization header or `bearer.<token>` subprotocol —
555/// browsers can't set WS headers directly). Previously the notifier hub
556/// accepted any connection and streamed every ChangeEvent/presence event
557/// to it, which was a silent read-policy bypass.
558///
559/// `snapshot_fetcher` is optional — when present, the reader will ship
560/// the current CRDT snapshot to the subscribing client immediately on
561/// `crdt-subscribe`, so the new tab sees the latest converged state
562/// without waiting for the next write. When absent, subscribe is still
563/// recorded but the catch-up frame is skipped.
564pub fn start_ws_server(
565    hub: Arc<WsHub>,
566    sessions: Arc<SessionStore>,
567    port: u16,
568    snapshot_fetcher: Option<SnapshotFetcher>,
569) {
570    let addr = format!("0.0.0.0:{port}");
571    let listener = match TcpListener::bind(&addr) {
572        Ok(l) => l,
573        Err(e) => {
574            tracing::warn!("[ws] Failed to bind on {addr}: {e}");
575            return;
576        }
577    };
578
579    tracing::warn!(
580        "[ws] WebSocket server listening on ws://localhost:{port} (sharded, {NUM_SHARDS} shards)"
581    );
582
583    let ip_counter = Arc::new(IpConnCounter::default());
584
585    for stream in listener.incoming() {
586        let stream = match stream {
587            Ok(s) => s,
588            Err(_) => continue,
589        };
590
591        // Per-IP connection cap: reject BEFORE the handshake so a cheap
592        // connect storm doesn't force us through tungstenite's HTTP parse
593        // and the session-resolve round trip. The guard is dropped when
594        // the reader thread exits (or fails to start), freeing the slot.
595        let ip = match stream.peer_addr() {
596            Ok(addr) => addr.ip(),
597            Err(_) => continue,
598        };
599        let guard = match ip_counter.acquire(ip) {
600            Some(g) => g,
601            None => {
602                // Ignore: let the client re-try after an existing connection
603                // closes. Previously an IP could open unbounded connections
604                // and each one spawned a thread + held a per-client mutex.
605                continue;
606            }
607        };
608
609        let hub = Arc::clone(&hub);
610        let sessions = Arc::clone(&sessions);
611        let fetcher = snapshot_fetcher.clone();
612        // Spawn a reader thread per client with a small stack.
613        // 64KB stack * 10k connections = ~640MB, vs 2-8MB default * 10k = 20-80GB.
614        let spawn_result = thread::Builder::new()
615            .name("ws-client".into())
616            .stack_size(64 * 1024)
617            .spawn(move || {
618                // Holding `guard` for the life of the connection thread is
619                // what makes the decrement-on-disconnect contract work. Not
620                // `let _ = guard;` — that drops immediately.
621                let _conn_slot = guard;
622                handle_ws_connection(hub, sessions, stream, fetcher);
623            });
624        if spawn_result.is_err() {
625            // Thread creation failed — guard is already dropped here, slot
626            // returned. We deliberately don't call `continue` before the
627            // spawn: we've paid the acquire cost and want to avoid leaking
628            // a slot under transient thread-limit pressure.
629        }
630    }
631}
632
633/// Handle a single WebSocket client connection.
634///
635/// Sets a read timeout to prevent zombie threads on dead connections.
636/// Handles ping/pong for keepalive, presence/topic message relay,
637/// and clean disconnect with presence broadcast.
638fn handle_ws_connection(
639    hub: Arc<WsHub>,
640    sessions: Arc<SessionStore>,
641    stream: TcpStream,
642    snapshot_fetcher: Option<SnapshotFetcher>,
643) {
644    // Short read timeout bounds how long the PER-CLIENT mutex is held
645    // while this thread is blocked in socket.read(). Each client now has
646    // its own mutex (not a shard-wide one), so a quiet client only stalls
647    // the broadcaster when it's broadcasting to THAT specific client —
648    // other clients in the same shard proceed without contention.
649    stream.set_read_timeout(Some(WS_READ_TIMEOUT)).ok();
650    // Also cap write time. A stuck kernel send (slow client, full send
651    // buffer, dropped packets) would otherwise stall the shard's
652    // broadcast worker holding this client's mutex — backpressure
653    // becomes head-of-line blocking for everyone. Capped at 5s; slow
654    // clients get disconnected rather than stalling the hub.
655    stream.set_write_timeout(Some(WS_READ_TIMEOUT)).ok();
656
657    // Extract the bearer token from the handshake, preferring the
658    // Authorization header (native clients) and falling back to the
659    // `bearer.<token>` WebSocket subprotocol (browsers). We only learn
660    // whether the token is valid AFTER accept_hdr completes, since the
661    // header callback must return synchronously with a Response.
662    let token_slot: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
663    let slot_for_cb = Arc::clone(&token_slot);
664    // Cap WebSocket frame size to bound memory per connection. The
665    // tungstenite default (64 MiB) is too generous — a single client
666    // can shovel huge frames and starve other connections. The cap
667    // applies BIDIRECTIONALLY (server-sent CRDT snapshots are
668    // checked against it too), so the default must accommodate the
669    // largest legitimate snapshot — 16 MiB covers Loro docs with
670    // long histories. Operators tune via PYLON_WS_MAX_FRAME (bytes)
671    // when they have unusually large or unusually small docs.
672    let max_frame: usize = std::env::var("PYLON_WS_MAX_FRAME")
673        .ok()
674        .and_then(|v| v.parse().ok())
675        .unwrap_or(16 * 1024 * 1024);
676    let ws_config = WebSocketConfig {
677        max_message_size: Some(max_frame),
678        max_frame_size: Some(max_frame),
679        ..Default::default()
680    };
681    let ws = match accept_hdr_with_config(
682        stream,
683        move |req: &Request, mut resp: Response| -> Result<Response, ErrorResponse> {
684            let mut chosen_protocol: Option<String> = None;
685            let mut auth: Option<String> = None;
686            for (name, value) in req.headers() {
687                let lower = name.as_str().to_ascii_lowercase();
688                if lower == "authorization" {
689                    if let Ok(v) = value.to_str() {
690                        if let Some(tok) = v.strip_prefix("Bearer ") {
691                            auth = Some(tok.to_string());
692                        }
693                    }
694                } else if lower == "sec-websocket-protocol" {
695                    if let Ok(v) = value.to_str() {
696                        for proto in v.split(',').map(str::trim) {
697                            if let Some(encoded) = proto.strip_prefix("bearer.") {
698                                if let Some(decoded) = percent_decode_token(encoded) {
699                                    auth = auth.or(Some(decoded));
700                                    chosen_protocol = Some(proto.to_string());
701                                    break;
702                                }
703                            }
704                        }
705                    }
706                }
707            }
708            // RFC 6455 §11.3.4 — echo the chosen subprotocol in the response or
709            // browsers will refuse the connection.
710            if let Some(chosen) = chosen_protocol {
711                if let Ok(hv) = tungstenite::http::HeaderValue::from_str(&chosen) {
712                    resp.headers_mut().insert("Sec-WebSocket-Protocol", hv);
713                }
714            }
715            *slot_for_cb.lock().unwrap() = auth;
716            Ok(resp)
717        },
718        Some(ws_config),
719    ) {
720        Ok(ws) => ws,
721        Err(_) => return,
722    };
723
724    // Reject unauthenticated or invalid-token handshakes AFTER accept —
725    // tungstenite's handshake callback can't easily return a 401 without
726    // a custom error response, and we already have the socket open for
727    // a clean close frame.
728    let token = token_slot.lock().unwrap().clone();
729    let auth_ctx = sessions.resolve(token.as_deref());
730    if auth_ctx.user_id.is_none() && !auth_ctx.is_admin {
731        let mut ws = ws;
732        let _ = ws.close(Some(tungstenite::protocol::CloseFrame {
733            code: tungstenite::protocol::frame::coding::CloseCode::Policy,
734            reason: "unauthorized: bearer token required".into(),
735        }));
736        return;
737    }
738
739    let (client_id, socket_handle) = hub.add_client(ws);
740
741    loop {
742        // Lock this client's socket mutex only for the duration of the
743        // read. With a 5s read timeout, broadcasters waiting to send to
744        // THIS client wait at most 5s. Other clients are never blocked
745        // by this lock — they have their own.
746        let msg = {
747            let mut guard = match socket_handle.lock() {
748                Ok(g) => g,
749                Err(poisoned) => poisoned.into_inner(),
750            };
751            guard.read()
752        };
753
754        match msg {
755            Ok(Message::Text(text)) => {
756                // Parse once and dispatch on the type field instead of
757                // matching prefix bytes — that approach silently dropped
758                // valid JSON with whitespace, key reordering, or any
759                // other formatting variation. Non-object / no-`type`
760                // messages are ignored.
761                let parsed: serde_json::Value = match serde_json::from_str(&text) {
762                    Ok(v) => v,
763                    Err(_) => continue,
764                };
765                let kind = parsed.get("type").and_then(|v| v.as_str()).unwrap_or("");
766                match kind {
767                    "presence" | "topic" => {
768                        // Stamp the authenticated sender server-side,
769                        // overriding any client-provided `from`. Without
770                        // this, any client could spoof presence/topic
771                        // events as another user — every connected
772                        // client would see a forged "alice typed…"
773                        // message attributed to alice.
774                        let mut stamped = parsed.clone();
775                        if let Some(obj) = stamped.as_object_mut() {
776                            let from = auth_ctx
777                                .user_id
778                                .clone()
779                                .unwrap_or_else(|| "admin".to_string());
780                            obj.insert("from".into(), serde_json::Value::String(from));
781                        }
782                        hub.broadcast_presence(&stamped.to_string());
783                    }
784                    "crdt-subscribe" | "crdt-unsubscribe" => handle_crdt_control(
785                        &hub,
786                        client_id,
787                        &auth_ctx,
788                        kind,
789                        &parsed,
790                        snapshot_fetcher.as_ref(),
791                    ),
792                    _ => {}
793                }
794            }
795            Ok(Message::Ping(data)) => {
796                // Respond with pong to keep the connection alive.
797                if let Ok(mut guard) = socket_handle.lock() {
798                    let _ = guard.send(Message::Pong(data));
799                }
800            }
801            Ok(Message::Close(_)) => {
802                // Drop every CRDT subscription this client held BEFORE
803                // remove_client so the broadcast path can never look up
804                // a stale client_id between the two ops.
805                hub.subscriptions.unsubscribe_all(client_id);
806                hub.remove_client(client_id);
807                let disconnect = serde_json::json!({
808                    "type": "presence",
809                    "event": "disconnect",
810                    "clientId": client_id,
811                });
812                hub.broadcast_presence(&disconnect.to_string());
813                break;
814            }
815            Err(tungstenite::Error::Io(io_err))
816                if io_err.kind() == std::io::ErrorKind::WouldBlock
817                    || io_err.kind() == std::io::ErrorKind::TimedOut =>
818            {
819                // Read timed out — this is EXPECTED with the short
820                // timeout. In theory the mutex is released between
821                // iterations, but `std::sync::Mutex` is not fair: a tight
822                // loop of lock→read→unlock→lock starves the broadcaster
823                // that's been waiting on the same mutex. Explicitly sleep
824                // for a tick so the broadcaster gets scheduled. 1ms is
825                // long enough to hand off, short enough that client→server
826                // latency stays sub-5ms.
827                std::thread::sleep(std::time::Duration::from_millis(1));
828                continue;
829            }
830            Err(_) => {
831                hub.subscriptions.unsubscribe_all(client_id);
832                hub.remove_client(client_id);
833                let disconnect = serde_json::json!({
834                    "type": "presence",
835                    "event": "disconnect",
836                    "clientId": client_id,
837                });
838                hub.broadcast_presence(&disconnect.to_string());
839                break;
840            }
841            _ => {}
842        }
843    }
844}
845
846/// Apply a parsed `crdt-subscribe` / `crdt-unsubscribe` control
847/// message. Both messages have the shape:
848///
849///   { "type": "crdt-subscribe",   "entity": "<E>", "rowId": "<id>" }
850///   { "type": "crdt-unsubscribe", "entity": "<E>", "rowId": "<id>" }
851///
852/// On subscribe the snapshot fetcher checks read policy for the
853/// caller's auth context — if the caller can't read the row we
854/// register no subscription and ship nothing back, so a malicious
855/// client can't peek at a row their query policy would block by
856/// just subscribing to its CRDT stream.
857///
858/// Malformed messages are silently dropped — there's no client-visible
859/// ACK protocol, so a typo in the payload would just look like a
860/// row that never receives updates. Logging would invite a noise
861/// channel for misbehaving clients.
862fn handle_crdt_control(
863    hub: &Arc<WsHub>,
864    client_id: u64,
865    auth_ctx: &pylon_auth::AuthContext,
866    kind: &str,
867    parsed: &serde_json::Value,
868    snapshot_fetcher: Option<&SnapshotFetcher>,
869) {
870    let entity = match parsed.get("entity").and_then(|v| v.as_str()) {
871        Some(e) if !e.is_empty() => e,
872        _ => return,
873    };
874    let row_id = match parsed
875        .get("rowId")
876        .or_else(|| parsed.get("row_id"))
877        .and_then(|v| v.as_str())
878    {
879        Some(r) if !r.is_empty() => r,
880        _ => return,
881    };
882
883    match kind {
884        "crdt-subscribe" => {
885            // Authz check happens INSIDE the fetcher (it has access to
886            // the policy engine + DataStore). When a fetcher is wired
887            // and returns None, the caller is either denied or the row
888            // doesn't exist — in both cases we refuse to register the
889            // subscription so a denied caller can't silently hold an
890            // open slot waiting for future writes.
891            //
892            // When no fetcher is wired (test harnesses, future
893            // workers backend without DataStore access) we trust the
894            // caller and register without the auth gate. Production
895            // server.rs always wires one, so this loophole is
896            // unreachable in deployed configurations.
897            let snapshot = snapshot_fetcher.and_then(|f| f(auth_ctx, entity, row_id));
898            let allow_subscribe = snapshot_fetcher.is_none() || snapshot.is_some();
899            if allow_subscribe {
900                hub.subscriptions.subscribe(client_id, entity, row_id);
901                if let Some(bytes) = snapshot {
902                    hub.send_binary_to_one(client_id, bytes);
903                }
904            }
905        }
906        "crdt-unsubscribe" => {
907            hub.subscriptions.unsubscribe(client_id, entity, row_id);
908        }
909        _ => {}
910    }
911}
912
/// Percent-decode the `<token>` part of a `bearer.<token>` subprotocol
/// entry, using RFC 3986 `%XX` escapes.
///
/// Each `%` must be followed by exactly two hex digits, and the decoded
/// bytes must be valid UTF-8. Otherwise this returns `None` rather than
/// handing garbage to the session store, where it would just fail to
/// resolve and look like a plain unauthenticated attempt. Every byte
/// other than `%` is copied through unchanged.
///
/// `+` is deliberately NOT decoded to a space. That rule belongs to
/// `application/x-www-form-urlencoded`, not to percent-encoding. `+` is
/// also a legal raw character in a `Sec-WebSocket-Protocol` token, so
/// mapping it would corrupt tokens that contain a literal `+`, such as
/// standard base64. Clients that encode with `encodeURIComponent` emit
/// `%2B` and `%20` anyway, so their tokens decode identically.
fn percent_decode_token(s: &str) -> Option<String> {
    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            // `get` yields None when fewer than two bytes follow the `%`,
            // which rejects a truncated escape at the end of the input.
            let hex = bytes.get(i + 1..i + 3)?;
            let hi = (hex[0] as char).to_digit(16)?;
            let lo = (hex[1] as char).to_digit(16)?;
            // hi and lo are both <= 0xF, so the combined value fits in a u8.
            out.push(((hi << 4) | lo) as u8);
            i += 3;
        } else {
            out.push(bytes[i]);
            i += 1;
        }
    }
    String::from_utf8(out).ok()
}
944
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn shard_count_starts_at_zero() {
        let shard = Shard::new();
        assert_eq!(shard.count(), 0);
    }

    #[test]
    fn hub_starts_with_zero_clients() {
        let hub = WsHub::new();
        assert_eq!(hub.client_count(), 0);
    }

    #[test]
    fn broadcast_to_empty_hub_doesnt_panic() {
        let hub = WsHub::new();
        let event = ChangeEvent {
            seq: 1,
            entity: "Test".into(),
            row_id: "1".into(),
            kind: pylon_sync::ChangeKind::Insert,
            data: None,
            timestamp: String::new(),
        };
        hub.broadcast(&event);
        hub.broadcast_presence("test");
    }

    #[test]
    fn num_shards_is_power_of_two() {
        // Power-of-two shard count ensures even distribution with modulo.
        assert!(
            NUM_SHARDS.is_power_of_two(),
            "NUM_SHARDS ({NUM_SHARDS}) must be a power of two for even distribution"
        );
    }

    #[test]
    fn crdt_subscriptions_subscribe_dedups() {
        let subs = CrdtSubscriptions::default();
        subs.subscribe(1, "Channel", "abc");
        subs.subscribe(1, "Channel", "abc");
        assert_eq!(subs.subscribers("Channel", "abc"), vec![1]);
        assert_eq!(subs.total_subscriptions(), 1);
    }

    #[test]
    fn crdt_subscriptions_returns_all_subscribers() {
        let subs = CrdtSubscriptions::default();
        subs.subscribe(1, "Channel", "abc");
        subs.subscribe(2, "Channel", "abc");
        subs.subscribe(3, "Channel", "abc");
        let mut ids = subs.subscribers("Channel", "abc");
        ids.sort();
        assert_eq!(ids, vec![1, 2, 3]);
    }

    #[test]
    fn crdt_subscriptions_unsubscribe_cleans_empty_rows() {
        let subs = CrdtSubscriptions::default();
        subs.subscribe(1, "Channel", "abc");
        subs.unsubscribe(1, "Channel", "abc");
        assert!(subs.subscribers("Channel", "abc").is_empty());
        // total should drop the empty by_row entry, not leave a 0-set
        // around forever.
        assert_eq!(subs.total_subscriptions(), 0);
    }

    #[test]
    fn crdt_subscriptions_unsubscribe_all_drops_every_row() {
        let subs = CrdtSubscriptions::default();
        subs.subscribe(1, "Channel", "a");
        subs.subscribe(1, "Channel", "b");
        subs.subscribe(1, "Message", "m1");
        subs.subscribe(2, "Channel", "a"); // someone else, must survive
        subs.unsubscribe_all(1);
        assert!(subs.subscribers("Channel", "b").is_empty());
        assert!(subs.subscribers("Message", "m1").is_empty());
        // Client 2 is still there.
        assert_eq!(subs.subscribers("Channel", "a"), vec![2]);
        // Nothing of client 1 lingers in the registry. Client 2's single
        // (entity, row) subscription is all that is left.
        assert_eq!(subs.total_subscriptions(), 1);
    }

    #[test]
    fn crdt_subscriptions_unsubscribe_unknown_client_is_noop() {
        let subs = CrdtSubscriptions::default();
        subs.unsubscribe(99, "Channel", "abc");
        subs.unsubscribe_all(99);
        assert_eq!(subs.total_subscriptions(), 0);
    }

    #[test]
    fn crdt_subscriptions_concurrent_subscribe_and_unsubscribe() {
        // Hammer subscribe + unsubscribe from many threads to verify
        // the single-mutex design keeps by_row and by_client in sync.
        // Previous two-mutex version could leave the maps divergent
        // under interleaving.
        let subs = Arc::new(CrdtSubscriptions::default());
        let mut handles = Vec::new();
        for client_id in 0..16u64 {
            let subs = Arc::clone(&subs);
            handles.push(std::thread::spawn(move || {
                for i in 0..200 {
                    let row = format!("row-{i}");
                    subs.subscribe(client_id, "Channel", &row);
                    subs.unsubscribe(client_id, "Channel", &row);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        // Every subscribe paired with an unsubscribe — registry must be
        // fully drained.
        assert_eq!(subs.total_subscriptions(), 0);
    }

    #[test]
    fn crdt_subscriptions_unsubscribe_all_after_concurrent_subscribes() {
        let subs = Arc::new(CrdtSubscriptions::default());
        let mut handles = Vec::new();
        for client_id in 0..8u64 {
            let subs = Arc::clone(&subs);
            handles.push(std::thread::spawn(move || {
                for i in 0..100 {
                    let row = format!("row-{i}");
                    subs.subscribe(client_id, "Channel", &row);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        // Now wipe each client and confirm no orphan rows remain.
        for client_id in 0..8u64 {
            subs.unsubscribe_all(client_id);
        }
        assert_eq!(subs.total_subscriptions(), 0);
    }

    #[test]
    fn percent_decode_token_decodes_and_rejects_malformed() {
        assert_eq!(percent_decode_token("abc").as_deref(), Some("abc"));
        assert_eq!(percent_decode_token("a%2Fb%3D").as_deref(), Some("a/b="));
        assert_eq!(percent_decode_token("%C3%A9").as_deref(), Some("é"));
        // Truncated escape, non-hex digits, and invalid UTF-8 all fail
        // closed instead of passing a garbled token to the session store.
        assert_eq!(percent_decode_token("%4"), None);
        assert_eq!(percent_decode_token("abc%"), None);
        assert_eq!(percent_decode_token("%zz"), None);
        assert_eq!(percent_decode_token("%FF"), None);
    }

    #[test]
    fn shard_assignment_distributes_evenly() {
        // NOTE(review): this checks only the `id % NUM_SHARDS` arithmetic
        // written inline below. It never calls the hub's own
        // shard-selection code, so it would keep passing if that code
        // changed. TODO: route the lookup through the real helper.
        // Verify that sequential IDs spread across all shards.
        let mut counts = vec![0usize; NUM_SHARDS];
        for id in 0..(NUM_SHARDS as u64 * 100) {
            counts[(id as usize) % NUM_SHARDS] += 1;
        }
        // Every shard should get exactly 100 clients.
        for (i, count) in counts.iter().enumerate() {
            assert_eq!(*count, 100, "Shard {i} got {count} clients, expected 100");
        }
    }
}