use std::collections::HashMap;
use std::io::Write;
use std::net::{TcpListener, TcpStream};
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
use pylon_sync::ChangeEvent;
use crate::ip_limit::{IpConnCounter, IpConnGuard};
/// Number of shards the connected clients are split across. `SseHub::new` creates this many
/// shards, each with its own broadcast worker thread and keepalive thread.
const NUM_SHARDS: usize = 16;
/// A connected SSE subscriber as stored inside a shard.
struct SseClient {
/// Socket that event frames and keepalive comments are written to.
stream: TcpStream,
/// Never read (hence the leading underscore); stored so the guard is dropped together with
/// this client entry.
/// NOTE(review): presumably releases a per-IP connection slot on drop — confirm in `ip_limit`.
_guard: Option<IpConnGuard>,
}
/// Capacity of each shard's bounded broadcast channel. When a shard's queue is full,
/// `SseHub::send_to_all` drops the message for that shard and logs a warning instead of blocking.
const BROADCAST_QUEUE_DEPTH: usize = 1024;
/// One partition of the connected clients, protected by its own mutex.
struct SseShard {
/// Clients registered in this shard, keyed by the id passed to `add`.
clients: Mutex<HashMap<u64, SseClient>>,
}
impl SseShard {
    /// Creates a shard with no connected clients.
    fn new() -> Self {
        Self {
            clients: Mutex::new(HashMap::new()),
        }
    }

    /// Registers a client under `id`, replacing any entry already stored under that id.
    ///
    /// `guard` is kept next to the stream so it is dropped together with the client entry.
    ///
    /// # Panics
    /// Panics if the shard mutex is poisoned.
    fn add(&self, id: u64, stream: TcpStream, guard: Option<IpConnGuard>) {
        self.clients.lock().unwrap().insert(
            id,
            SseClient {
                stream,
                _guard: guard,
            },
        );
    }

    /// Removes the client stored under `id`, if any.
    ///
    /// # Panics
    /// Panics if the shard mutex is poisoned.
    // Nothing in this file calls `remove`; the lint is silenced rather than deleting the method.
    #[allow(dead_code)]
    fn remove(&self, id: u64) {
        self.clients.lock().unwrap().remove(&id);
    }

    /// Sends `data` as one SSE event to every client in the shard.
    ///
    /// Clients whose write or flush fails are removed from the shard; their ids are returned.
    ///
    /// # Panics
    /// Panics if the shard mutex is poisoned.
    fn broadcast(&self, data: &str) -> Vec<u64> {
        // Build the frame before taking the lock so formatting never extends the critical section.
        let frame = Self::frame_event(data);
        self.write_to_all(frame.as_bytes())
    }

    /// Sends an SSE comment line (`: keepalive`) to every client and removes the ones whose
    /// write or flush fails.
    ///
    /// # Panics
    /// Panics if the shard mutex is poisoned.
    fn keepalive(&self) {
        self.write_to_all(b": keepalive\n\n");
    }

    /// Returns the number of clients currently registered in this shard.
    ///
    /// # Panics
    /// Panics if the shard mutex is poisoned.
    fn count(&self) -> usize {
        self.clients.lock().unwrap().len()
    }

    /// Encodes `data` as a single SSE event.
    ///
    /// A payload without line breaks becomes `data: <payload>\n\n`. A payload containing CR, LF
    /// or CRLF is emitted as one `data:` line per payload line: a raw line break inside a `data:`
    /// field would otherwise end the field (or the whole event) early and let the remainder be
    /// parsed as separate fields or events.
    fn frame_event(data: &str) -> String {
        fn is_line_break(c: char) -> bool {
            c == '\n' || c == '\r'
        }

        if !data.contains(is_line_break) {
            return format!("data: {data}\n\n");
        }

        // Collapse CRLF first so it counts as one line break rather than two.
        let normalized = data.replace("\r\n", "\n");
        let mut frame = String::with_capacity(normalized.len() + 16);
        for line in normalized.split(is_line_break) {
            frame.push_str("data: ");
            frame.push_str(line);
            frame.push('\n');
        }
        // Blank line terminates the event.
        frame.push('\n');
        frame
    }

    /// Writes `payload` to every client while holding the shard lock.
    ///
    /// A client is dropped from the map on its first failed write or flush, which also drops its
    /// stream and guard. Returns the ids of the clients that were removed.
    fn write_to_all(&self, payload: &[u8]) -> Vec<u64> {
        let mut dead = Vec::new();
        self.clients.lock().unwrap().retain(|id, client| {
            let alive =
                client.stream.write_all(payload).is_ok() && client.stream.flush().is_ok();
            if !alive {
                dead.push(*id);
            }
            alive
        });
        dead
    }
}
/// Fan-out hub that distributes messages to SSE clients spread across `NUM_SHARDS` shards.
pub struct SseHub {
/// Client shards; `add_client` picks one from the client id modulo `NUM_SHARDS`.
shards: Vec<Arc<SseShard>>,
/// Next client id to hand out.
next_id: Mutex<u64>,
/// One bounded sender per shard (same index as `shards`), feeding that shard's broadcast worker.
broadcast_txs: Vec<mpsc::SyncSender<String>>,
}
impl SseHub {
pub fn new() -> Arc<Self> {
let mut shards = Vec::with_capacity(NUM_SHARDS);
let mut broadcast_txs = Vec::with_capacity(NUM_SHARDS);
for i in 0..NUM_SHARDS {
let shard = Arc::new(SseShard::new());
let (tx, rx) = mpsc::sync_channel::<String>(BROADCAST_QUEUE_DEPTH);
let shard_clone = Arc::clone(&shard);
thread::Builder::new()
.name(format!("sse-broadcast-{i}"))
.spawn(move || {
while let Ok(msg) = rx.recv() {
shard_clone.broadcast(&msg);
}
})
.expect("Failed to spawn SSE broadcast worker");
let shard_ka = Arc::clone(&shard);
thread::Builder::new()
.name(format!("sse-keepalive-{i}"))
.spawn(move || loop {
thread::sleep(Duration::from_secs(30));
shard_ka.keepalive();
})
.expect("Failed to spawn SSE keepalive worker");
shards.push(shard);
broadcast_txs.push(tx);
}
Arc::new(Self {
shards,
next_id: Mutex::new(0),
broadcast_txs,
})
}
pub fn broadcast(&self, event: &ChangeEvent) {
let json = match serde_json::to_string(event) {
Ok(j) => j,
Err(_) => return,
};
self.send_to_all(&json);
}
pub fn broadcast_message(&self, msg: &str) {
self.send_to_all(msg);
}
fn send_to_all(&self, msg: &str) {
for tx in &self.broadcast_txs {
match tx.try_send(msg.to_string()) {
Ok(()) => {}
Err(mpsc::TrySendError::Full(_)) => {
tracing::warn!("[sse] broadcast queue full — dropping event for one shard");
}
Err(mpsc::TrySendError::Disconnected(_)) => {}
}
}
}
fn add_client(&self, stream: TcpStream, guard: Option<IpConnGuard>) -> u64 {
let mut next_id = self.next_id.lock().unwrap();
let id = *next_id;
*next_id += 1;
let shard_idx = (id as usize) % NUM_SHARDS;
self.shards[shard_idx].add(id, stream, guard);
id
}
pub fn client_count(&self) -> usize {
self.shards.iter().map(|s| s.count()).sum()
}
}
pub fn start_sse_server(hub: Arc<SseHub>, port: u16) {
let addr = format!("0.0.0.0:{port}");
let listener = match TcpListener::bind(&addr) {
Ok(l) => l,
Err(e) => {
tracing::warn!("[sse] Failed to bind on {addr}: {e}");
return;
}
};
tracing::warn!(
"[sse] SSE server listening on http://localhost:{port}/events (sharded, {NUM_SHARDS} shards)"
);
let ip_counter = Arc::new(IpConnCounter::default());
for stream in listener.incoming() {
let stream = match stream {
Ok(s) => s,
Err(_) => continue,
};
let ip = match stream.peer_addr() {
Ok(addr) => addr.ip(),
Err(_) => continue,
};
let guard = match ip_counter.acquire(ip) {
Some(g) => g,
None => continue,
};
let hub = Arc::clone(&hub);
thread::Builder::new()
.name("sse-accept".into())
.stack_size(64 * 1024)
.spawn(move || {
handle_sse_connection(hub, stream, guard);
})
.ok();
}
}
fn handle_sse_connection(hub: Arc<SseHub>, mut stream: TcpStream, guard: IpConnGuard) {
let mut buf = [0u8; 2048];
let _ = std::io::Read::read(&mut stream, &mut buf);
stream.set_nodelay(true).ok();
let headers = "HTTP/1.1 200 OK\r\n\
Content-Type: text/event-stream\r\n\
Cache-Control: no-cache\r\n\
Connection: keep-alive\r\n\
Access-Control-Allow-Origin: *\r\n\
X-Content-Type-Options: nosniff\r\n\
\r\n";
if stream.write_all(headers.as_bytes()).is_err() {
return;
}
if stream.write_all(b": connected\n\n").is_err() {
return;
}
let _ = stream.flush();
hub.add_client(stream, Some(guard));
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a connected loopback TCP pair as `(server_side, client_side)`.
    fn socket_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind loopback listener");
        let addr = listener.local_addr().expect("listener address");
        let client = TcpStream::connect(addr).expect("connect to loopback listener");
        let (server, _) = listener.accept().expect("accept loopback connection");
        (server, client)
    }

    #[test]
    fn hub_starts_with_correct_shard_count() {
        let hub = SseHub::new();
        assert_eq!(hub.shards.len(), NUM_SHARDS);
        assert_eq!(hub.broadcast_txs.len(), NUM_SHARDS);
    }

    #[test]
    fn hub_starts_empty() {
        let hub = SseHub::new();
        assert_eq!(hub.client_count(), 0);
    }

    #[test]
    fn broadcast_on_empty_hub_does_not_panic() {
        let hub = SseHub::new();
        hub.broadcast_message("hello");
        // Give the broadcast workers a moment to drain their queues.
        thread::sleep(Duration::from_millis(50));
        assert_eq!(hub.client_count(), 0);
    }

    #[test]
    fn keepalive_on_empty_shard_does_not_panic() {
        let shard = SseShard::new();
        shard.keepalive();
        assert_eq!(shard.count(), 0);
    }

    #[test]
    fn broadcast_on_empty_shard_returns_no_dead() {
        let shard = SseShard::new();
        let dead = shard.broadcast("test");
        assert!(dead.is_empty());
    }

    #[test]
    fn client_ids_are_sequential() {
        let hub = SseHub::new();
        assert_eq!(*hub.next_id.lock().unwrap(), 0);

        // Keep the peer ends alive so the registered streams stay connected.
        let (first, _first_peer) = socket_pair();
        let (second, _second_peer) = socket_pair();

        assert_eq!(hub.add_client(first, None), 0);
        assert_eq!(hub.add_client(second, None), 1);
        assert_eq!(*hub.next_id.lock().unwrap(), 2);
        assert_eq!(hub.client_count(), 2);
    }

    #[test]
    fn shard_broadcast_reaches_connected_client() {
        let shard = SseShard::new();
        let (server, mut client) = socket_pair();
        shard.add(7, server, None);

        let dead = shard.broadcast("hello");
        assert!(dead.is_empty());
        assert_eq!(shard.count(), 1);

        // Bound the read so a regression fails the test instead of hanging it.
        client
            .set_read_timeout(Some(Duration::from_secs(2)))
            .expect("set read timeout");
        let mut received = [0u8; 13];
        std::io::Read::read_exact(&mut client, &mut received).expect("read SSE frame");
        assert_eq!(&received, b"data: hello\n\n");
    }
}