1use std::collections::HashMap;
2use std::io::Write;
3use std::net::{TcpListener, TcpStream};
4use std::sync::mpsc;
5use std::sync::{Arc, Mutex};
6use std::thread;
7use std::time::Duration;
8
9use pylon_sync::ChangeEvent;
10
11use crate::ip_limit::{IpConnCounter, IpConnGuard};
12
/// Number of independent client shards. Each shard has its own client lock, broadcast
/// worker thread and keepalive worker thread; a client with id `n` is placed in shard
/// `n % NUM_SHARDS`.
const NUM_SHARDS: usize = 16;
14
/// One registered SSE subscriber.
struct SseClient {
    /// Socket that SSE frames are written to directly (no buffering layer).
    stream: TcpStream,
    /// Never read; held only so it is dropped together with the client. When present it
    /// comes from `IpConnCounter::acquire` in the accept loop. Presumably its `Drop`
    /// releases the per-IP connection slot — confirm in `ip_limit`. Fields drop in
    /// declaration order, so the socket is closed before this guard is released.
    _guard: Option<IpConnGuard>,
}
23
/// Capacity of each shard's bounded broadcast queue. When a shard's queue is full,
/// `SseHub::send_to_all` drops the event for that shard (logging a warning) instead of
/// blocking the caller.
const BROADCAST_QUEUE_DEPTH: usize = 1024;
29
/// A subset of the hub's SSE clients behind a single mutex.
///
/// Every write to these clients (broadcasts and keepalives) happens while `clients` is
/// locked. A client whose socket write blocks therefore delays every other client on the
/// same shard, as well as new registrations into it.
struct SseShard {
    /// Connected clients keyed by the id assigned in `SseHub::add_client`.
    clients: Mutex<HashMap<u64, SseClient>>,
}
36
37impl SseShard {
38 fn new() -> Self {
39 Self {
40 clients: Mutex::new(HashMap::new()),
41 }
42 }
43
44 fn add(&self, id: u64, stream: TcpStream, guard: Option<IpConnGuard>) {
45 self.clients.lock().unwrap().insert(
46 id,
47 SseClient {
48 stream,
49 _guard: guard,
50 },
51 );
52 }
53
54 #[allow(dead_code)]
55 fn remove(&self, id: u64) {
56 self.clients.lock().unwrap().remove(&id);
57 }
58
59 fn broadcast(&self, data: &str) -> Vec<u64> {
62 let sse_data = format!("data: {data}\n\n");
63 let mut clients = self.clients.lock().unwrap();
64 let mut dead = Vec::new();
65 for (id, client) in clients.iter_mut() {
66 if client.stream.write_all(sse_data.as_bytes()).is_err()
67 || client.stream.flush().is_err()
68 {
69 dead.push(*id);
70 }
71 }
72 for id in &dead {
73 clients.remove(id);
74 }
75 dead
76 }
77
78 fn keepalive(&self) {
80 let mut clients = self.clients.lock().unwrap();
81 let mut dead = Vec::new();
82 for (id, client) in clients.iter_mut() {
83 if client.stream.write_all(b": keepalive\n\n").is_err()
84 || client.stream.flush().is_err()
85 {
86 dead.push(*id);
87 }
88 }
89 for id in dead {
90 clients.remove(&id);
91 }
92 }
93
94 fn count(&self) -> usize {
95 self.clients.lock().unwrap().len()
96 }
97}
98
/// Fan-out hub for Server-Sent Events clients, split into `NUM_SHARDS` shards.
///
/// Broadcasts are queued per shard and written to sockets by that shard's worker thread.
/// Callers of `broadcast` / `broadcast_message` therefore never perform socket I/O
/// themselves.
pub struct SseHub {
    /// One shard per index `0..NUM_SHARDS`; a client with id `n` lives in
    /// `shards[n % NUM_SHARDS]`.
    shards: Vec<Arc<SseShard>>,
    /// Id handed to the next client registered via `add_client`; incremented once per client.
    next_id: Mutex<u64>,
    /// `broadcast_txs[i]` feeds the worker thread that writes to `shards[i]`.
    broadcast_txs: Vec<mpsc::SyncSender<String>>,
}
112
113impl SseHub {
114 pub fn new() -> Arc<Self> {
115 let mut shards = Vec::with_capacity(NUM_SHARDS);
116 let mut broadcast_txs = Vec::with_capacity(NUM_SHARDS);
117
118 for i in 0..NUM_SHARDS {
119 let shard = Arc::new(SseShard::new());
120 let (tx, rx) = mpsc::sync_channel::<String>(BROADCAST_QUEUE_DEPTH);
121
122 let shard_clone = Arc::clone(&shard);
125 thread::Builder::new()
126 .name(format!("sse-broadcast-{i}"))
127 .spawn(move || {
128 while let Ok(msg) = rx.recv() {
129 shard_clone.broadcast(&msg);
130 }
131 })
132 .expect("Failed to spawn SSE broadcast worker");
133
134 let shard_ka = Arc::clone(&shard);
137 thread::Builder::new()
138 .name(format!("sse-keepalive-{i}"))
139 .spawn(move || loop {
140 thread::sleep(Duration::from_secs(30));
141 shard_ka.keepalive();
142 })
143 .expect("Failed to spawn SSE keepalive worker");
144
145 shards.push(shard);
146 broadcast_txs.push(tx);
147 }
148
149 Arc::new(Self {
150 shards,
151 next_id: Mutex::new(0),
152 broadcast_txs,
153 })
154 }
155
156 pub fn broadcast(&self, event: &ChangeEvent) {
158 let json = match serde_json::to_string(event) {
159 Ok(j) => j,
160 Err(_) => return,
161 };
162 self.send_to_all(&json);
163 }
164
165 pub fn broadcast_message(&self, msg: &str) {
167 self.send_to_all(msg);
168 }
169
170 fn send_to_all(&self, msg: &str) {
172 for tx in &self.broadcast_txs {
173 match tx.try_send(msg.to_string()) {
174 Ok(()) => {}
175 Err(mpsc::TrySendError::Full(_)) => {
176 tracing::warn!("[sse] broadcast queue full — dropping event for one shard");
177 }
178 Err(mpsc::TrySendError::Disconnected(_)) => {}
179 }
180 }
181 }
182
183 fn add_client(&self, stream: TcpStream, guard: Option<IpConnGuard>) -> u64 {
190 let mut next_id = self.next_id.lock().unwrap();
191 let id = *next_id;
192 *next_id += 1;
193 let shard_idx = (id as usize) % NUM_SHARDS;
194 self.shards[shard_idx].add(id, stream, guard);
195 id
196 }
197
198 pub fn client_count(&self) -> usize {
200 self.shards.iter().map(|s| s.count()).sum()
201 }
202}
203
204pub fn start_sse_server(hub: Arc<SseHub>, port: u16) {
210 let addr = format!("0.0.0.0:{port}");
211 let listener = match TcpListener::bind(&addr) {
212 Ok(l) => l,
213 Err(e) => {
214 tracing::warn!("[sse] Failed to bind on {addr}: {e}");
215 return;
216 }
217 };
218
219 tracing::warn!(
220 "[sse] SSE server listening on http://localhost:{port}/events (sharded, {NUM_SHARDS} shards)"
221 );
222
223 let ip_counter = Arc::new(IpConnCounter::default());
227
228 for stream in listener.incoming() {
229 let stream = match stream {
230 Ok(s) => s,
231 Err(_) => continue,
232 };
233
234 let ip = match stream.peer_addr() {
235 Ok(addr) => addr.ip(),
236 Err(_) => continue,
237 };
238 let guard = match ip_counter.acquire(ip) {
239 Some(g) => g,
240 None => continue,
241 };
242
243 let hub = Arc::clone(&hub);
244 thread::Builder::new()
248 .name("sse-accept".into())
249 .stack_size(64 * 1024)
250 .spawn(move || {
251 handle_sse_connection(hub, stream, guard);
252 })
253 .ok();
254 }
255}
256
257fn handle_sse_connection(hub: Arc<SseHub>, mut stream: TcpStream, guard: IpConnGuard) {
258 let mut buf = [0u8; 2048];
261 let _ = std::io::Read::read(&mut stream, &mut buf);
262
263 stream.set_nodelay(true).ok();
265
266 let headers = "HTTP/1.1 200 OK\r\n\
268 Content-Type: text/event-stream\r\n\
269 Cache-Control: no-cache\r\n\
270 Connection: keep-alive\r\n\
271 Access-Control-Allow-Origin: *\r\n\
272 X-Content-Type-Options: nosniff\r\n\
273 \r\n";
274
275 if stream.write_all(headers.as_bytes()).is_err() {
276 return;
277 }
278 if stream.write_all(b": connected\n\n").is_err() {
279 return;
280 }
281 let _ = stream.flush();
282
283 hub.add_client(stream, Some(guard));
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Read;

    /// Returns a connected `(client, server)` loopback socket pair.
    ///
    /// The server end is what `SseHub::add_client` would receive from the accept loop.
    /// The client end gets a read timeout so a missing frame fails the test instead of
    /// hanging it.
    fn loopback_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind loopback listener");
        let client = TcpStream::connect(listener.local_addr().expect("listener address"))
            .expect("connect to loopback listener");
        let (server, _) = listener.accept().expect("accept loopback connection");
        client
            .set_read_timeout(Some(Duration::from_secs(5)))
            .expect("set client read timeout");
        (client, server)
    }

    #[test]
    fn hub_starts_with_correct_shard_count() {
        let hub = SseHub::new();
        assert_eq!(hub.shards.len(), NUM_SHARDS);
        assert_eq!(hub.broadcast_txs.len(), NUM_SHARDS);
    }

    #[test]
    fn hub_starts_empty() {
        let hub = SseHub::new();
        assert_eq!(hub.client_count(), 0);
    }

    #[test]
    fn broadcast_on_empty_hub_does_not_panic() {
        let hub = SseHub::new();
        hub.broadcast_message("hello");
        thread::sleep(Duration::from_millis(50));
        assert_eq!(hub.client_count(), 0);
    }

    #[test]
    fn keepalive_on_empty_shard_does_not_panic() {
        let shard = SseShard::new();
        shard.keepalive();
        assert_eq!(shard.count(), 0);
    }

    #[test]
    fn broadcast_on_empty_shard_returns_no_dead() {
        let shard = SseShard::new();
        let dead = shard.broadcast("test");
        assert!(dead.is_empty());
    }

    #[test]
    fn client_ids_are_sequential() {
        let hub = SseHub::new();
        assert_eq!(*hub.next_id.lock().unwrap(), 0);

        // Keep the client ends open so the connections stay healthy for the whole test.
        let mut peers = Vec::new();
        for expected in 0..3u64 {
            let (client, server) = loopback_pair();
            assert_eq!(hub.add_client(server, None), expected);
            peers.push(client);
        }

        assert_eq!(*hub.next_id.lock().unwrap(), 3);
        assert_eq!(hub.client_count(), 3);
        // Consecutive ids land on consecutive shards.
        for (idx, shard) in hub.shards.iter().enumerate() {
            assert_eq!(shard.count(), usize::from(idx < 3));
        }
    }

    #[test]
    fn broadcast_message_reaches_connected_client() {
        let hub = SseHub::new();
        let (mut client, server) = loopback_pair();
        hub.add_client(server, None);

        hub.broadcast_message("hello");

        let expected = b"data: hello\n\n";
        let mut got = [0u8; 13];
        client.read_exact(&mut got).expect("read broadcast frame");
        assert_eq!(&got, expected);
    }
}