use std::collections::{HashMap, HashSet};
use std::net::{TcpListener, TcpStream};
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
use pylon_auth::SessionStore;
use pylon_sync::ChangeEvent;
use tungstenite::handshake::server::{ErrorResponse, Request, Response};
use tungstenite::{accept_hdr_with_config, protocol::WebSocketConfig, Message, WebSocket};
use crate::ip_limit::IpConnCounter;
/// Two-way index between clients and the `(entity, row_id)` rows they follow.
///
/// Both maps describe the same relation and are always updated together while
/// the owning [`CrdtSubscriptions`] holds its lock.
#[derive(Default)]
struct SubsState {
    /// Row key `(entity, row_id)` -> ids of the clients subscribed to it.
    by_row: HashMap<(String, String), HashSet<u64>>,
    /// Client id -> row keys that client is subscribed to.
    by_client: HashMap<u64, HashSet<(String, String)>>,
}

impl SubsState {
    /// Removes `client_id` from the subscriber set of `key` and drops the set
    /// once it becomes empty, so `by_row` never retains empty entries.
    fn detach_from_row(&mut self, key: &(String, String), client_id: u64) {
        if let Some(clients) = self.by_row.get_mut(key) {
            clients.remove(&client_id);
            if clients.is_empty() {
                self.by_row.remove(key);
            }
        }
    }
}

/// Registry of per-row subscriptions, guarded by a single mutex.
///
/// Every method locks the internal state for the duration of the call.
///
/// # Panics
///
/// All methods panic if the internal mutex has been poisoned.
#[derive(Default)]
pub struct CrdtSubscriptions {
    state: Mutex<SubsState>,
}

impl CrdtSubscriptions {
    /// Creates an empty registry wrapped in an [`Arc`].
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }

    /// Subscribes `client_id` to the row `(entity, row_id)`.
    ///
    /// Subscribing twice to the same row has no additional effect.
    pub fn subscribe(&self, client_id: u64, entity: &str, row_id: &str) {
        let key = (entity.to_string(), row_id.to_string());
        let mut state = self.state.lock().unwrap();
        state
            .by_row
            .entry(key.clone())
            .or_default()
            .insert(client_id);
        state.by_client.entry(client_id).or_default().insert(key);
    }

    /// Removes the subscription of `client_id` to `(entity, row_id)`.
    ///
    /// Unknown clients or rows are ignored. Entries that become empty are
    /// removed from both maps.
    pub fn unsubscribe(&self, client_id: u64, entity: &str, row_id: &str) {
        let key = (entity.to_string(), row_id.to_string());
        let mut state = self.state.lock().unwrap();
        state.detach_from_row(&key, client_id);
        if let Some(rows) = state.by_client.get_mut(&client_id) {
            rows.remove(&key);
            if rows.is_empty() {
                state.by_client.remove(&client_id);
            }
        }
    }

    /// Removes every subscription held by `client_id`.
    ///
    /// A client without subscriptions is a no-op.
    pub fn unsubscribe_all(&self, client_id: u64) {
        let mut state = self.state.lock().unwrap();
        if let Some(rows) = state.by_client.remove(&client_id) {
            for key in &rows {
                state.detach_from_row(key, client_id);
            }
        }
    }

    /// Returns the ids of all clients subscribed to `(entity, row_id)`.
    ///
    /// The order of the returned ids is unspecified (it follows `HashSet`
    /// iteration order). Returns an empty vector when nobody is subscribed.
    pub fn subscribers(&self, entity: &str, row_id: &str) -> Vec<u64> {
        let key = (entity.to_string(), row_id.to_string());
        let state = self.state.lock().unwrap();
        state
            .by_row
            .get(&key)
            .map(|clients| clients.iter().copied().collect())
            .unwrap_or_default()
    }

    /// Returns the total number of `(client, row)` subscription pairs.
    pub fn total_subscriptions(&self) -> usize {
        let state = self.state.lock().unwrap();
        state.by_row.values().map(|clients| clients.len()).sum()
    }
}
/// Number of client shards. A client with id `n` is stored in shard
/// `n % NUM_SHARDS`, and each shard gets its own broadcast worker thread.
const NUM_SHARDS: usize = 16;
/// Capacity of each shard's bounded text-broadcast channel. When a shard's
/// queue is full, `WsHub::broadcast_shared` drops the event for that shard.
const BROADCAST_QUEUE_DEPTH: usize = 1024;
/// Timeout applied to both reads and writes on every accepted client stream
/// (see `handle_ws_connection`).
const WS_READ_TIMEOUT: Duration = Duration::from_millis(200);
/// Shared, lockable handle to one client's WebSocket. The connection's reader
/// loop and the shard's send paths each hold a clone and lock it per operation.
type ClientSocket = Arc<Mutex<WebSocket<TcpStream>>>;
/// One partition of the connected clients, with its own lock so that
/// operations on different shards do not contend with each other.
struct Shard {
/// Sockets registered in this shard, keyed by client id.
clients: Mutex<HashMap<u64, ClientSocket>>,
}
impl Shard {
fn new() -> Self {
Self {
clients: Mutex::new(HashMap::new()),
}
}
fn add(&self, id: u64, ws: WebSocket<TcpStream>) -> ClientSocket {
let handle = Arc::new(Mutex::new(ws));
self.clients.lock().unwrap().insert(id, Arc::clone(&handle));
handle
}
fn remove(&self, id: u64) {
self.clients.lock().unwrap().remove(&id);
}
fn broadcast(&self, msg: &Arc<str>) {
let handles: Vec<(u64, ClientSocket)> = {
let clients = self.clients.lock().unwrap();
clients.iter().map(|(id, h)| (*id, Arc::clone(h))).collect()
};
let mut dead: Vec<u64> = Vec::new();
for (id, handle) in handles {
let mut guard = match handle.lock() {
Ok(g) => g,
Err(poisoned) => poisoned.into_inner(),
};
if guard.send(Message::Text((**msg).to_string())).is_err() {
dead.push(id);
}
}
if !dead.is_empty() {
let mut clients = self.clients.lock().unwrap();
for id in &dead {
clients.remove(id);
}
}
}
fn send_binary_to(&self, ids: &[u64], msg: &Arc<[u8]>) -> Vec<u64> {
let handles: Vec<(u64, ClientSocket)> = {
let clients = self.clients.lock().unwrap();
ids.iter()
.filter_map(|id| clients.get(id).map(|h| (*id, Arc::clone(h))))
.collect()
};
let mut dead: Vec<u64> = Vec::new();
for (id, handle) in handles {
let mut guard = match handle.lock() {
Ok(g) => g,
Err(poisoned) => poisoned.into_inner(),
};
if guard.send(Message::Binary(msg.to_vec())).is_err() {
dead.push(id);
}
}
if !dead.is_empty() {
let mut clients = self.clients.lock().unwrap();
for id in &dead {
clients.remove(id);
}
}
dead
}
fn broadcast_binary(&self, msg: &Arc<[u8]>) {
let handles: Vec<(u64, ClientSocket)> = {
let clients = self.clients.lock().unwrap();
clients.iter().map(|(id, h)| (*id, Arc::clone(h))).collect()
};
let mut dead: Vec<u64> = Vec::new();
for (id, handle) in handles {
let mut guard = match handle.lock() {
Ok(g) => g,
Err(poisoned) => poisoned.into_inner(),
};
if guard.send(Message::Binary(msg.to_vec())).is_err() {
dead.push(id);
}
}
if !dead.is_empty() {
let mut clients = self.clients.lock().unwrap();
for id in &dead {
clients.remove(id);
}
}
}
fn count(&self) -> usize {
self.clients.lock().unwrap().len()
}
}
/// Central registry of connected WebSocket clients, split into shards, plus
/// the per-row CRDT subscription index.
pub struct WsHub {
/// Client shards; `add_client` places id `n` in `shards[n % NUM_SHARDS]`.
shards: Vec<Arc<Shard>>,
/// Next client id to hand out; incremented by `add_client`.
next_id: Mutex<u64>,
/// One bounded sender per shard, in the same order as `shards`, feeding that
/// shard's broadcast worker thread.
broadcast_txs: Vec<mpsc::SyncSender<Arc<str>>>,
/// Queue capacity the channels were created with. Not read in the visible
/// code, hence the `dead_code` allowance.
#[allow(dead_code)]
queue_depth: usize,
/// Which clients are subscribed to which `(entity, row_id)` rows.
subscriptions: Arc<CrdtSubscriptions>,
}
impl WsHub {
    /// Creates a hub with `NUM_SHARDS` empty shards and spawns one
    /// `ws-broadcast-{i}` worker thread per shard.
    ///
    /// Each worker drains its shard's bounded queue and exits once the hub's
    /// sender for that shard is dropped.
    ///
    /// # Panics
    ///
    /// Panics if a worker thread cannot be spawned.
    pub fn new() -> Arc<Self> {
        let mut shards = Vec::with_capacity(NUM_SHARDS);
        let mut broadcast_txs = Vec::with_capacity(NUM_SHARDS);
        for i in 0..NUM_SHARDS {
            let shard = Arc::new(Shard::new());
            let (tx, rx) = mpsc::sync_channel::<Arc<str>>(BROADCAST_QUEUE_DEPTH);
            let worker_shard = Arc::clone(&shard);
            thread::Builder::new()
                .name(format!("ws-broadcast-{i}"))
                .spawn(move || {
                    while let Ok(msg) = rx.recv() {
                        worker_shard.broadcast(&msg);
                    }
                })
                .expect("Failed to spawn broadcast worker");
            shards.push(shard);
            broadcast_txs.push(tx);
        }
        Arc::new(Self {
            shards,
            next_id: Mutex::new(0),
            broadcast_txs,
            queue_depth: BROADCAST_QUEUE_DEPTH,
            subscriptions: CrdtSubscriptions::new(),
        })
    }

    /// Returns the hub's CRDT subscription registry.
    pub fn subscriptions(&self) -> &Arc<CrdtSubscriptions> {
        &self.subscriptions
    }

    /// Serializes `event` to JSON and queues it for every shard.
    ///
    /// If serialization fails the event is dropped and a warning is logged.
    pub fn broadcast(&self, event: &ChangeEvent) {
        match serde_json::to_string(event) {
            Ok(json) => self.broadcast_shared(Arc::from(json)),
            Err(e) => {
                tracing::warn!("[ws] failed to serialize change event — dropping it: {e}");
            }
        }
    }

    /// Queues the already-serialized text `msg` for every shard.
    pub fn broadcast_presence(&self, msg: &str) {
        // `Arc<str>: From<&str>` copies the text once, straight into the Arc.
        self.broadcast_shared(Arc::from(msg));
    }

    /// Sends `bytes` as a binary frame to every connected client.
    ///
    /// Unlike the text broadcasts this runs on the calling thread and does not
    /// go through the per-shard queues.
    pub fn broadcast_binary(&self, bytes: Vec<u8>) {
        let shared: Arc<[u8]> = Arc::from(bytes);
        for shard in &self.shards {
            shard.broadcast_binary(&shared);
        }
    }

    /// Sends `bytes` as a binary frame to the given clients only.
    ///
    /// Clients whose send fails are dropped from their shard and lose all of
    /// their CRDT subscriptions. An empty `client_ids` is a no-op.
    pub fn broadcast_binary_to(&self, client_ids: &[u64], bytes: Vec<u8>) {
        if client_ids.is_empty() {
            return;
        }
        let shared: Arc<[u8]> = Arc::from(bytes);
        // Group the ids by shard so each shard's registry is locked once.
        let mut by_shard: Vec<Vec<u64>> = vec![Vec::new(); NUM_SHARDS];
        for &id in client_ids {
            by_shard[Self::shard_index(id)].push(id);
        }
        for (shard, ids) in self.shards.iter().zip(&by_shard) {
            if ids.is_empty() {
                continue;
            }
            self.drop_subscriptions(shard.send_binary_to(ids, &shared));
        }
    }

    /// Sends `bytes` as a binary frame to a single client.
    ///
    /// If the send fails the client is dropped from its shard and loses all of
    /// its CRDT subscriptions.
    pub fn send_binary_to_one(&self, client_id: u64, bytes: Vec<u8>) {
        let shared: Arc<[u8]> = Arc::from(bytes);
        let shard = &self.shards[Self::shard_index(client_id)];
        self.drop_subscriptions(shard.send_binary_to(&[client_id], &shared));
    }

    /// Hands `msg` to every shard's broadcast worker without blocking.
    ///
    /// A full queue drops the message for that shard and logs a warning.
    fn broadcast_shared(&self, msg: Arc<str>) {
        for tx in &self.broadcast_txs {
            match tx.try_send(Arc::clone(&msg)) {
                Ok(()) => {}
                Err(mpsc::TrySendError::Full(_)) => {
                    tracing::warn!("[ws] broadcast queue full — dropping event for one shard");
                }
                // The receiver lives in the shard's worker thread, so this
                // means that worker has exited; there is nobody to deliver to.
                Err(mpsc::TrySendError::Disconnected(_)) => {}
            }
        }
    }

    /// Assigns the next client id to `ws` and registers it in its shard.
    ///
    /// Returns the id together with the shared socket handle.
    fn add_client(&self, ws: WebSocket<TcpStream>) -> (u64, ClientSocket) {
        // Release the id counter before touching the shard registry.
        let id = {
            let mut next_id = self.next_id.lock().unwrap();
            let id = *next_id;
            *next_id += 1;
            id
        };
        let handle = self.shards[Self::shard_index(id)].add(id, ws);
        (id, handle)
    }

    /// Unregisters client `id` from its shard; unknown ids are ignored.
    fn remove_client(&self, id: u64) {
        self.shards[Self::shard_index(id)].remove(id);
    }

    /// Returns the number of clients registered across all shards.
    pub fn client_count(&self) -> usize {
        self.shards.iter().map(|s| s.count()).sum()
    }

    /// Maps a client id to the index of the shard that owns it.
    fn shard_index(id: u64) -> usize {
        (id as usize) % NUM_SHARDS
    }

    /// Removes every CRDT subscription held by the given (dead) clients.
    fn drop_subscriptions(&self, dead: Vec<u64>) {
        for dead_id in dead {
            self.subscriptions.unsubscribe_all(dead_id);
        }
    }
}
/// Callback that loads the binary snapshot sent to a client when it
/// subscribes to a row.
///
/// It is called as `fetcher(auth_ctx, entity, row_id)`. In
/// `handle_crdt_control`, a `None` result from a configured fetcher causes the
/// subscription to be refused.
/// NOTE(review): presumably `None` covers both "row missing" and "access
/// denied" — confirm against the fetcher implementations.
pub type SnapshotFetcher =
Arc<dyn Fn(&pylon_auth::AuthContext, &str, &str) -> Option<Vec<u8>> + Send + Sync>;
/// Binds `0.0.0.0:{port}` and runs the WebSocket accept loop on the calling
/// thread.
///
/// Each accepted connection is served by `handle_ws_connection` on its own
/// `ws-client` thread with a 64 KiB stack. If the bind fails, a warning is
/// logged and the function returns.
///
/// * `hub` - registry the accepted clients are added to.
/// * `sessions` - store used to resolve bearer tokens.
/// * `port` - TCP port to listen on.
/// * `snapshot_fetcher` - optional callback forwarded to each connection.
pub fn start_ws_server(
    hub: Arc<WsHub>,
    sessions: Arc<SessionStore>,
    port: u16,
    snapshot_fetcher: Option<SnapshotFetcher>,
) {
    let addr = format!("0.0.0.0:{port}");
    let listener = match TcpListener::bind(&addr) {
        Ok(l) => l,
        Err(e) => {
            tracing::warn!("[ws] Failed to bind on {addr}: {e}");
            return;
        }
    };
    tracing::warn!(
        "[ws] WebSocket server listening on ws://localhost:{port} (sharded, {NUM_SHARDS} shards)"
    );

    let ip_counter = Arc::new(IpConnCounter::default());
    for incoming in listener.incoming() {
        let stream = match incoming {
            Ok(s) => s,
            Err(_) => continue,
        };
        let ip = match stream.peer_addr() {
            Ok(peer) => peer.ip(),
            Err(_) => continue,
        };
        // NOTE(review): `acquire` presumably enforces a per-IP connection cap
        // and returns `None` when it is exceeded — confirm in `crate::ip_limit`.
        // Skipping the iteration drops `stream`, which closes the connection.
        let guard = match ip_counter.acquire(ip) {
            Some(g) => g,
            None => continue,
        };

        let hub = Arc::clone(&hub);
        let sessions = Arc::clone(&sessions);
        let fetcher = snapshot_fetcher.clone();
        let spawn_result = thread::Builder::new()
            .name("ws-client".into())
            .stack_size(64 * 1024)
            .spawn(move || {
                // Keep the slot guard alive for the lifetime of the connection.
                let _conn_slot = guard;
                handle_ws_connection(hub, sessions, stream, fetcher);
            });
        // On failure the closure (with the stream and the slot guard) has
        // already been dropped; record why the connection went away.
        if let Err(e) = spawn_result {
            tracing::warn!("[ws] failed to spawn client thread — dropping connection: {e}");
        }
    }
}
/// Extracts the bearer token offered in a WebSocket handshake request.
///
/// Returns `(token, chosen_protocol)`:
/// * An `Authorization: Bearer <token>` header always sets the token,
///   regardless of header order.
/// * Otherwise the first `bearer.<percent-encoded token>` entry of a
///   `Sec-WebSocket-Protocol` header that decodes successfully supplies it.
/// * `chosen_protocol` is that subprotocol entry verbatim, set whenever such
///   an entry was found, even if the `Authorization` header supplied the token.
fn bearer_from_handshake(req: &Request) -> (Option<String>, Option<String>) {
    let mut auth: Option<String> = None;
    let mut chosen_protocol: Option<String> = None;
    for (name, value) in req.headers() {
        let header = name.as_str();
        if header.eq_ignore_ascii_case("authorization") {
            if let Ok(v) = value.to_str() {
                if let Some(tok) = v.strip_prefix("Bearer ") {
                    auth = Some(tok.to_string());
                }
            }
        } else if header.eq_ignore_ascii_case("sec-websocket-protocol") {
            if let Ok(v) = value.to_str() {
                for proto in v.split(',').map(str::trim) {
                    if let Some(encoded) = proto.strip_prefix("bearer.") {
                        if let Some(decoded) = percent_decode_token(encoded) {
                            // Never override a token taken from `Authorization`.
                            auth = auth.or(Some(decoded));
                            chosen_protocol = Some(proto.to_string());
                            break;
                        }
                    }
                }
            }
        }
    }
    (auth, chosen_protocol)
}

/// Tears down all hub state for `client_id` and broadcasts a presence
/// `disconnect` event for it.
fn announce_disconnect(hub: &WsHub, client_id: u64) {
    hub.subscriptions.unsubscribe_all(client_id);
    hub.remove_client(client_id);
    let disconnect = serde_json::json!({
        "type": "presence",
        "event": "disconnect",
        "clientId": client_id,
    });
    hub.broadcast_presence(&disconnect.to_string());
}

/// Serves one accepted TCP connection for its whole lifetime.
///
/// Steps:
/// 1. Perform the WebSocket handshake, capturing the bearer token and echoing
///    back the chosen `bearer.*` subprotocol if one was offered.
/// 2. Resolve the token through `sessions`. A connection that is neither a
///    known user nor an admin gets a policy close frame and is dropped.
/// 3. Register the socket with `hub` and loop reading frames:
///    * `presence` / `topic` text messages are stamped with `from` and
///      rebroadcast.
///    * `crdt-subscribe` / `crdt-unsubscribe` go to `handle_crdt_control`.
///    * Pings are answered with a pong.
///    * Read timeouts are retried.
///    * A close frame or any other read error ends the loop.
/// 4. On exit, remove the client and broadcast a `disconnect` presence event.
///
/// The maximum frame/message size defaults to 16 MiB and can be overridden
/// with the `PYLON_WS_MAX_FRAME` environment variable.
fn handle_ws_connection(
    hub: Arc<WsHub>,
    sessions: Arc<SessionStore>,
    stream: TcpStream,
    snapshot_fetcher: Option<SnapshotFetcher>,
) {
    // Best effort: failures to set the timeouts are ignored.
    stream.set_read_timeout(Some(WS_READ_TIMEOUT)).ok();
    stream.set_write_timeout(Some(WS_READ_TIMEOUT)).ok();

    // The handshake callback is moved into the acceptor, so the token is
    // passed back out through a shared slot.
    let token_slot: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
    let slot_for_cb = Arc::clone(&token_slot);

    let max_frame: usize = std::env::var("PYLON_WS_MAX_FRAME")
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(16 * 1024 * 1024);
    let ws_config = WebSocketConfig {
        max_message_size: Some(max_frame),
        max_frame_size: Some(max_frame),
        ..Default::default()
    };

    let handshake = accept_hdr_with_config(
        stream,
        move |req: &Request, mut resp: Response| -> Result<Response, ErrorResponse> {
            let (auth, chosen_protocol) = bearer_from_handshake(req);
            if let Some(chosen) = chosen_protocol {
                if let Ok(hv) = tungstenite::http::HeaderValue::from_str(&chosen) {
                    resp.headers_mut().insert("Sec-WebSocket-Protocol", hv);
                }
            }
            *slot_for_cb.lock().unwrap() = auth;
            Ok(resp)
        },
        Some(ws_config),
    );
    let mut ws = match handshake {
        Ok(ws) => ws,
        Err(_) => return,
    };

    let token = token_slot.lock().unwrap().clone();
    let auth_ctx = sessions.resolve(token.as_deref());
    if auth_ctx.user_id.is_none() && !auth_ctx.is_admin {
        let _ = ws.close(Some(tungstenite::protocol::CloseFrame {
            code: tungstenite::protocol::frame::coding::CloseCode::Policy,
            reason: "unauthorized: bearer token required".into(),
        }));
        return;
    }

    let (client_id, socket_handle) = hub.add_client(ws);
    loop {
        // Scope the guard so the socket lock is released before the frame is
        // dispatched; handlers may need the lock to send on this socket.
        let msg = {
            let mut guard = match socket_handle.lock() {
                Ok(g) => g,
                Err(poisoned) => poisoned.into_inner(),
            };
            guard.read()
        };
        match msg {
            Ok(Message::Text(text)) => {
                // Frames that are not valid JSON are ignored.
                let parsed: serde_json::Value = match serde_json::from_str(&text) {
                    Ok(v) => v,
                    Err(_) => continue,
                };
                let kind = parsed.get("type").and_then(|v| v.as_str()).unwrap_or("");
                match kind {
                    "presence" | "topic" => {
                        // Overwrite any client-supplied `from` with the
                        // authenticated identity before rebroadcasting.
                        let mut stamped = parsed.clone();
                        if let Some(obj) = stamped.as_object_mut() {
                            let from = auth_ctx
                                .user_id
                                .clone()
                                .unwrap_or_else(|| "admin".to_string());
                            obj.insert("from".into(), serde_json::Value::String(from));
                        }
                        hub.broadcast_presence(&stamped.to_string());
                    }
                    "crdt-subscribe" | "crdt-unsubscribe" => handle_crdt_control(
                        &hub,
                        client_id,
                        &auth_ctx,
                        kind,
                        &parsed,
                        snapshot_fetcher.as_ref(),
                    ),
                    _ => {}
                }
            }
            Ok(Message::Ping(data)) => {
                if let Ok(mut guard) = socket_handle.lock() {
                    let _ = guard.send(Message::Pong(data));
                }
            }
            Err(tungstenite::Error::Io(io_err))
                if io_err.kind() == std::io::ErrorKind::WouldBlock
                    || io_err.kind() == std::io::ErrorKind::TimedOut =>
            {
                // Read timeout, not a failure. The read above holds the socket
                // lock while it waits; presumably this pause gives senders a
                // window to acquire it.
                std::thread::sleep(std::time::Duration::from_millis(1));
                continue;
            }
            // A close frame and any other read error end the session alike.
            Ok(Message::Close(_)) | Err(_) => break,
            _ => {}
        }
    }
    announce_disconnect(&hub, client_id);
}
/// Applies a `crdt-subscribe` or `crdt-unsubscribe` control message.
///
/// The message must carry a non-empty string `entity` and a non-empty string
/// row id under `rowId` (or `row_id` when `rowId` is absent); otherwise it is
/// ignored.
///
/// On subscribe:
/// * with no fetcher configured, the client is subscribed and nothing is sent;
/// * with a fetcher, the client is subscribed only if the fetcher returns a
///   snapshot, which is then sent to that client as a binary frame.
///
/// Any other `kind` is ignored.
fn handle_crdt_control(
    hub: &Arc<WsHub>,
    client_id: u64,
    auth_ctx: &pylon_auth::AuthContext,
    kind: &str,
    parsed: &serde_json::Value,
    snapshot_fetcher: Option<&SnapshotFetcher>,
) {
    let entity = parsed
        .get("entity")
        .and_then(|v| v.as_str())
        .filter(|e| !e.is_empty());
    // `row_id` is only consulted when `rowId` is missing entirely.
    let row_id = parsed
        .get("rowId")
        .or_else(|| parsed.get("row_id"))
        .and_then(|v| v.as_str())
        .filter(|r| !r.is_empty());

    if let (Some(entity), Some(row_id)) = (entity, row_id) {
        if kind == "crdt-subscribe" {
            match snapshot_fetcher {
                None => hub.subscriptions.subscribe(client_id, entity, row_id),
                Some(fetch) => {
                    // A missing snapshot refuses the subscription.
                    if let Some(bytes) = fetch(auth_ctx, entity, row_id) {
                        hub.subscriptions.subscribe(client_id, entity, row_id);
                        hub.send_binary_to_one(client_id, bytes);
                    }
                }
            }
        } else if kind == "crdt-unsubscribe" {
            hub.subscriptions.unsubscribe(client_id, entity, row_id);
        }
    }
}
/// Decodes a percent-encoded token.
///
/// `%XX` escapes (hex digits in either case) become the corresponding byte,
/// `+` becomes a space, and every other byte is copied through unchanged.
///
/// Returns `None` when an escape is truncated, contains a non-hex digit, or
/// the decoded bytes are not valid UTF-8.
fn percent_decode_token(s: &str) -> Option<String> {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn hex_val(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    let mut decoded = Vec::with_capacity(s.len());
    let mut rest = s.bytes();
    while let Some(byte) = rest.next() {
        let out = match byte {
            b'%' => {
                // Both hex digits must be present and valid.
                let hi = hex_val(rest.next()?)?;
                let lo = hex_val(rest.next()?)?;
                (hi << 4) | lo
            }
            b'+' => b' ',
            other => other,
        };
        decoded.push(out);
    }
    String::from_utf8(decoded).ok()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new shard has no registered clients.
    #[test]
    fn shard_count_starts_at_zero() {
        assert_eq!(Shard::new().count(), 0);
    }

    /// A new hub reports zero clients across all shards.
    #[test]
    fn hub_starts_with_zero_clients() {
        assert_eq!(WsHub::new().client_count(), 0);
    }

    /// Broadcasting with nobody connected must not panic.
    #[test]
    fn broadcast_to_empty_hub_doesnt_panic() {
        let hub = WsHub::new();
        hub.broadcast(&ChangeEvent {
            seq: 1,
            entity: "Test".into(),
            row_id: "1".into(),
            kind: pylon_sync::ChangeKind::Insert,
            data: None,
            timestamp: String::new(),
        });
        hub.broadcast_presence("test");
    }

    #[test]
    fn num_shards_is_power_of_two() {
        assert!(
            NUM_SHARDS.is_power_of_two(),
            "NUM_SHARDS ({NUM_SHARDS}) must be a power of two for even distribution"
        );
    }

    /// Subscribing twice to the same row counts once.
    #[test]
    fn crdt_subscriptions_subscribe_dedups() {
        let registry = CrdtSubscriptions::default();
        for _ in 0..2 {
            registry.subscribe(1, "Channel", "abc");
        }
        assert_eq!(registry.subscribers("Channel", "abc"), vec![1]);
        assert_eq!(registry.total_subscriptions(), 1);
    }

    #[test]
    fn crdt_subscriptions_returns_all_subscribers() {
        let registry = CrdtSubscriptions::default();
        for client in 1..=3u64 {
            registry.subscribe(client, "Channel", "abc");
        }
        let mut found = registry.subscribers("Channel", "abc");
        found.sort();
        assert_eq!(found, vec![1, 2, 3]);
    }

    #[test]
    fn crdt_subscriptions_unsubscribe_cleans_empty_rows() {
        let registry = CrdtSubscriptions::default();
        registry.subscribe(1, "Channel", "abc");
        registry.unsubscribe(1, "Channel", "abc");
        assert!(registry.subscribers("Channel", "abc").is_empty());
        assert_eq!(registry.total_subscriptions(), 0);
    }

    /// `unsubscribe_all` removes one client's rows and leaves other clients alone.
    #[test]
    fn crdt_subscriptions_unsubscribe_all_drops_every_row() {
        let registry = CrdtSubscriptions::default();
        let rows_of_client_one = [("Channel", "a"), ("Channel", "b"), ("Message", "m1")];
        for (entity, row) in rows_of_client_one {
            registry.subscribe(1, entity, row);
        }
        registry.subscribe(2, "Channel", "a");

        registry.unsubscribe_all(1);

        assert!(registry.subscribers("Channel", "b").is_empty());
        assert!(registry.subscribers("Message", "m1").is_empty());
        assert_eq!(registry.subscribers("Channel", "a"), vec![2]);
    }

    #[test]
    fn crdt_subscriptions_unsubscribe_unknown_client_is_noop() {
        let registry = CrdtSubscriptions::default();
        registry.unsubscribe(99, "Channel", "abc");
        registry.unsubscribe_all(99);
        assert_eq!(registry.total_subscriptions(), 0);
    }

    /// Interleaved subscribe/unsubscribe from many threads leaves nothing behind.
    #[test]
    fn crdt_subscriptions_concurrent_subscribe_and_unsubscribe() {
        let registry = Arc::new(CrdtSubscriptions::default());
        let workers: Vec<_> = (0..16u64)
            .map(|client_id| {
                let registry = Arc::clone(&registry);
                std::thread::spawn(move || {
                    for i in 0..200 {
                        let row = format!("row-{i}");
                        registry.subscribe(client_id, "Channel", &row);
                        registry.unsubscribe(client_id, "Channel", &row);
                    }
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
        assert_eq!(registry.total_subscriptions(), 0);
    }

    /// Rows subscribed concurrently are all removed by `unsubscribe_all`.
    #[test]
    fn crdt_subscriptions_unsubscribe_all_after_concurrent_subscribes() {
        let registry = Arc::new(CrdtSubscriptions::default());
        let workers: Vec<_> = (0..8u64)
            .map(|client_id| {
                let registry = Arc::clone(&registry);
                std::thread::spawn(move || {
                    for i in 0..100 {
                        let row = format!("row-{i}");
                        registry.subscribe(client_id, "Channel", &row);
                    }
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
        (0..8u64).for_each(|client_id| registry.unsubscribe_all(client_id));
        assert_eq!(registry.total_subscriptions(), 0);
    }

    /// Sequential ids spread evenly over the shards.
    #[test]
    fn shard_assignment_distributes_evenly() {
        let mut per_shard = [0usize; NUM_SHARDS];
        for id in 0..(NUM_SHARDS as u64 * 100) {
            per_shard[(id as usize) % NUM_SHARDS] += 1;
        }
        for (i, count) in per_shard.iter().enumerate() {
            assert_eq!(*count, 100, "Shard {i} got {count} clients, expected 100");
        }
    }
}