//! `pylon_runtime/sse.rs` — sharded Server-Sent Events (SSE) broadcast hub
//! and the TCP accept loop that feeds it.

1use std::collections::HashMap;
2use std::io::Write;
3use std::net::{TcpListener, TcpStream};
4use std::sync::mpsc;
5use std::sync::{Arc, Mutex};
6use std::thread;
7use std::time::Duration;
8
9use pylon_sync::ChangeEvent;
10
11use crate::ip_limit::{IpConnCounter, IpConnGuard};
12
/// Number of shards the SSE client set is partitioned into. Each shard has
/// its own lock, one broadcast worker thread and one keepalive thread, and a
/// client with ID `id` is stored in shard `id % NUM_SHARDS`.
const NUM_SHARDS: usize = 16;
14
/// Per-client state in the shard map. The `_guard` is held for the lifetime
/// of the connection — dropping it (when the client is removed) releases
/// the client's slot in the per-IP connection counter. Without this, a
/// crash-loopy browser could open unlimited SSE streams.
struct SseClient {
    /// Socket the shard's broadcast and keepalive passes write SSE frames to.
    stream: TcpStream,
    /// Never read; stored only so that it is dropped together with the
    /// client entry. `None` when the client was registered without a guard.
    _guard: Option<IpConnGuard>,
}
23
/// Capacity (in messages) of each shard's bounded broadcast channel.
///
/// Same rationale as the WS hub: a bounded queue so a stuck subscriber can't
/// balloon memory on the broadcast path. Note that the sender uses
/// `try_send`, so when a shard's queue is full it is the *incoming* event
/// that is discarded for that shard (already-queued events are kept).
/// Clients that miss events catch up via the change-log cursor protocol on
/// reconnect — SSE is a notify-sooner, not a durable-delivery transport.
const BROADCAST_QUEUE_DEPTH: usize = 1024;
29
30/// A single shard holding a subset of SSE clients, protected by its own lock.
31/// Sharding reduces contention: concurrent broadcasts only block within the
32/// same shard, not across the entire client set.
33struct SseShard {
34    clients: Mutex<HashMap<u64, SseClient>>,
35}
36
37impl SseShard {
38    fn new() -> Self {
39        Self {
40            clients: Mutex::new(HashMap::new()),
41        }
42    }
43
44    fn add(&self, id: u64, stream: TcpStream, guard: Option<IpConnGuard>) {
45        self.clients.lock().unwrap().insert(
46            id,
47            SseClient {
48                stream,
49                _guard: guard,
50            },
51        );
52    }
53
54    #[allow(dead_code)]
55    fn remove(&self, id: u64) {
56        self.clients.lock().unwrap().remove(&id);
57    }
58
59    /// Send SSE-formatted data to every client in this shard.
60    /// Dead clients (write failures) are removed inline and their IDs returned.
61    fn broadcast(&self, data: &str) -> Vec<u64> {
62        let sse_data = format!("data: {data}\n\n");
63        let mut clients = self.clients.lock().unwrap();
64        let mut dead = Vec::new();
65        for (id, client) in clients.iter_mut() {
66            if client.stream.write_all(sse_data.as_bytes()).is_err()
67                || client.stream.flush().is_err()
68            {
69                dead.push(*id);
70            }
71        }
72        for id in &dead {
73            clients.remove(id);
74        }
75        dead
76    }
77
78    /// Send an SSE comment keepalive to every client. Removes dead clients.
79    fn keepalive(&self) {
80        let mut clients = self.clients.lock().unwrap();
81        let mut dead = Vec::new();
82        for (id, client) in clients.iter_mut() {
83            if client.stream.write_all(b": keepalive\n\n").is_err()
84                || client.stream.flush().is_err()
85            {
86                dead.push(*id);
87            }
88        }
89        for id in dead {
90            clients.remove(&id);
91        }
92    }
93
94    fn count(&self) -> usize {
95        self.clients.lock().unwrap().len()
96    }
97}
98
99/// Sharded SSE broadcast hub.
100///
101/// 16 shards partition clients by ID. Each shard has a dedicated broadcast
102/// worker thread (receives messages via `mpsc::channel`) and a keepalive
103/// thread that sends SSE comments every 30 seconds.
104///
105/// This means 10k connected SSE clients require only 32 background threads
106/// (16 broadcast + 16 keepalive) instead of 10k threads in the old design.
107pub struct SseHub {
108    shards: Vec<Arc<SseShard>>,
109    next_id: Mutex<u64>,
110    broadcast_txs: Vec<mpsc::SyncSender<String>>,
111}
112
113impl SseHub {
114    pub fn new() -> Arc<Self> {
115        let mut shards = Vec::with_capacity(NUM_SHARDS);
116        let mut broadcast_txs = Vec::with_capacity(NUM_SHARDS);
117
118        for i in 0..NUM_SHARDS {
119            let shard = Arc::new(SseShard::new());
120            let (tx, rx) = mpsc::sync_channel::<String>(BROADCAST_QUEUE_DEPTH);
121
122            // Broadcast worker: drains the channel and writes to every client
123            // in this shard. Runs until the channel is dropped (hub teardown).
124            let shard_clone = Arc::clone(&shard);
125            thread::Builder::new()
126                .name(format!("sse-broadcast-{i}"))
127                .spawn(move || {
128                    while let Ok(msg) = rx.recv() {
129                        shard_clone.broadcast(&msg);
130                    }
131                })
132                .expect("Failed to spawn SSE broadcast worker");
133
134            // Keepalive worker: sends an SSE comment every 30s to prevent
135            // proxies and load balancers from closing idle connections.
136            let shard_ka = Arc::clone(&shard);
137            thread::Builder::new()
138                .name(format!("sse-keepalive-{i}"))
139                .spawn(move || loop {
140                    thread::sleep(Duration::from_secs(30));
141                    shard_ka.keepalive();
142                })
143                .expect("Failed to spawn SSE keepalive worker");
144
145            shards.push(shard);
146            broadcast_txs.push(tx);
147        }
148
149        Arc::new(Self {
150            shards,
151            next_id: Mutex::new(0),
152            broadcast_txs,
153        })
154    }
155
156    /// Broadcast a `ChangeEvent` to all connected SSE clients.
157    pub fn broadcast(&self, event: &ChangeEvent) {
158        let json = match serde_json::to_string(event) {
159            Ok(j) => j,
160            Err(_) => return,
161        };
162        self.send_to_all(&json);
163    }
164
165    /// Broadcast an arbitrary string message (e.g. presence/topic updates).
166    pub fn broadcast_message(&self, msg: &str) {
167        self.send_to_all(msg);
168    }
169
170    /// Internal: bounded-queue send to all shard workers.
171    fn send_to_all(&self, msg: &str) {
172        for tx in &self.broadcast_txs {
173            match tx.try_send(msg.to_string()) {
174                Ok(()) => {}
175                Err(mpsc::TrySendError::Full(_)) => {
176                    tracing::warn!("[sse] broadcast queue full — dropping event for one shard");
177                }
178                Err(mpsc::TrySendError::Disconnected(_)) => {}
179            }
180        }
181    }
182
183    /// Register a new SSE client. Returns the assigned client ID.
184    /// The stream is moved into the appropriate shard — the caller should not
185    /// use it after this call. The optional `guard` binds the client's slot
186    /// in the per-IP connection counter to this client's presence in the
187    /// shard map; when the client is removed, the guard drops and the slot
188    /// is returned.
189    fn add_client(&self, stream: TcpStream, guard: Option<IpConnGuard>) -> u64 {
190        let mut next_id = self.next_id.lock().unwrap();
191        let id = *next_id;
192        *next_id += 1;
193        let shard_idx = (id as usize) % NUM_SHARDS;
194        self.shards[shard_idx].add(id, stream, guard);
195        id
196    }
197
198    /// Total number of connected SSE clients across all shards.
199    pub fn client_count(&self) -> usize {
200        self.shards.iter().map(|s| s.count()).sum()
201    }
202}
203
204/// Start the SSE server on the given port.
205///
206/// Accepts TCP connections, performs minimal HTTP parsing, sends SSE headers,
207/// and registers the stream with the hub. The accept thread exits immediately
208/// after registration — no per-client thread is kept alive.
209pub fn start_sse_server(hub: Arc<SseHub>, port: u16) {
210    let addr = format!("0.0.0.0:{port}");
211    let listener = match TcpListener::bind(&addr) {
212        Ok(l) => l,
213        Err(e) => {
214            tracing::warn!("[sse] Failed to bind on {addr}: {e}");
215            return;
216        }
217    };
218
219    tracing::warn!(
220        "[sse] SSE server listening on http://localhost:{port}/events (sharded, {NUM_SHARDS} shards)"
221    );
222
223    // Per-IP cap mirrors the one on /ws. Idle SSE streams are cheap, but a
224    // crash-loopy client can still accumulate thousands of them — this
225    // bounds that.
226    let ip_counter = Arc::new(IpConnCounter::default());
227
228    for stream in listener.incoming() {
229        let stream = match stream {
230            Ok(s) => s,
231            Err(_) => continue,
232        };
233
234        let ip = match stream.peer_addr() {
235            Ok(addr) => addr.ip(),
236            Err(_) => continue,
237        };
238        let guard = match ip_counter.acquire(ip) {
239            Some(g) => g,
240            None => continue,
241        };
242
243        let hub = Arc::clone(&hub);
244        // Lightweight accept thread with a small stack. It reads the HTTP
245        // request, writes SSE headers, registers the stream (transferring
246        // the IP-conn guard into the shard map), then exits.
247        thread::Builder::new()
248            .name("sse-accept".into())
249            .stack_size(64 * 1024)
250            .spawn(move || {
251                handle_sse_connection(hub, stream, guard);
252            })
253            .ok();
254    }
255}
256
257fn handle_sse_connection(hub: Arc<SseHub>, mut stream: TcpStream, guard: IpConnGuard) {
258    // Consume the HTTP request headers. We don't route — any connection
259    // to this port is treated as an SSE subscription.
260    let mut buf = [0u8; 2048];
261    let _ = std::io::Read::read(&mut stream, &mut buf);
262
263    // Disable Nagle for lower-latency event delivery.
264    stream.set_nodelay(true).ok();
265
266    // Send SSE response headers.
267    let headers = "HTTP/1.1 200 OK\r\n\
268                   Content-Type: text/event-stream\r\n\
269                   Cache-Control: no-cache\r\n\
270                   Connection: keep-alive\r\n\
271                   Access-Control-Allow-Origin: *\r\n\
272                   X-Content-Type-Options: nosniff\r\n\
273                   \r\n";
274
275    if stream.write_all(headers.as_bytes()).is_err() {
276        return;
277    }
278    if stream.write_all(b": connected\n\n").is_err() {
279        return;
280    }
281    let _ = stream.flush();
282
283    // Hand the stream AND the IP-conn guard to the hub. The shard's
284    // broadcast and keepalive workers now own writes; the guard is
285    // released when the client is dropped from the shard map. This
286    // thread exits.
287    hub.add_client(stream, Some(guard));
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Create a connected loopback socket pair: `(server side, peer side)`.
    fn loopback_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let peer = TcpStream::connect(addr).unwrap();
        let (server, _) = listener.accept().unwrap();
        (server, peer)
    }

    #[test]
    fn hub_starts_with_correct_shard_count() {
        let hub = SseHub::new();
        assert_eq!(hub.shards.len(), NUM_SHARDS);
        assert_eq!(hub.broadcast_txs.len(), NUM_SHARDS);
    }

    #[test]
    fn hub_starts_empty() {
        let hub = SseHub::new();
        assert_eq!(hub.client_count(), 0);
    }

    #[test]
    fn broadcast_on_empty_hub_does_not_panic() {
        let hub = SseHub::new();
        hub.broadcast_message("hello");
        // Give broadcast workers time to process.
        thread::sleep(Duration::from_millis(50));
        assert_eq!(hub.client_count(), 0);
    }

    #[test]
    fn keepalive_on_empty_shard_does_not_panic() {
        let shard = SseShard::new();
        shard.keepalive();
        assert_eq!(shard.count(), 0);
    }

    #[test]
    fn broadcast_on_empty_shard_returns_no_dead() {
        let shard = SseShard::new();
        let dead = shard.broadcast("test");
        assert!(dead.is_empty());
    }

    #[test]
    fn client_ids_are_sequential() {
        let hub = SseHub::new();
        assert_eq!(*hub.next_id.lock().unwrap(), 0);

        // Keep the peer ends open so the registered sockets stay connected.
        let mut peers = Vec::new();
        for expected_id in 0..3u64 {
            let (server, peer) = loopback_pair();
            peers.push(peer);
            assert_eq!(hub.add_client(server, None), expected_id);
        }

        assert_eq!(hub.client_count(), 3);
        assert_eq!(*hub.next_id.lock().unwrap(), 3);
    }

    #[test]
    fn shard_broadcast_writes_sse_frame_to_client() {
        use std::io::Read;

        let (server, mut peer) = loopback_pair();
        let shard = SseShard::new();
        shard.add(7, server, None);

        assert!(shard.broadcast("hello").is_empty());
        assert_eq!(shard.count(), 1);

        // Bounded wait so a regression fails the test instead of hanging it.
        peer.set_read_timeout(Some(Duration::from_secs(2))).unwrap();
        let mut received = [0u8; 13];
        peer.read_exact(&mut received).unwrap();
        assert_eq!(&received, b"data: hello\n\n");
    }
}