use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use dashmap::{DashMap, DashSet};
use serde::Serialize;
use tokio::sync::mpsc;
use crate::web::context::Claims;
/// Identifier of a registered WebSocket connection, allocated by
/// [`ConnectionRegistry::register`] from a per-registry counter starting at 1.
pub type ConnId = u64;
/// A frame queued on a connection's outbound channel.
///
/// Derives `PartialEq`/`Eq` so frames can be compared directly (e.g. in tests)
/// instead of through `matches!` guards.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WsMessage {
    /// Text frame. `Arc<str>` lets a broadcast share one allocation across every
    /// recipient's queue.
    Text(Arc<str>),
    /// Ping frame.
    Ping,
    /// Close frame; enqueued by `ConnectionRegistry::close_all` and `sweep_idle`.
    Close,
}
/// Per-connection state stored in [`ConnectionRegistry`].
struct ConnEntry {
    /// Outbound queue for this connection. The registry only ever uses `try_send`,
    /// so a full queue is detected immediately rather than awaited.
    tx: mpsc::Sender<WsMessage>,
    /// Names of the rooms this connection has joined, kept so `unregister` can
    /// remove the connection from each room's member set.
    member_rooms: DashSet<String>,
    /// Unix time (seconds) of registration or the last `touch`; read by `sweep_idle`.
    last_seen: AtomicU64,
    // NOTE(review): `claims` is stored but never read in this file, hence the
    // `dead_code` allow; presumably reserved for per-connection authorization — confirm.
    #[allow(dead_code)]
    claims: Option<Arc<Claims>>,
}
/// Wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` rather than panicking.
#[inline]
fn unix_now() -> u64 {
    let wall_clock = std::time::SystemTime::now();
    wall_clock
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_secs())
}
/// Registry of live WebSocket connections and their room memberships.
///
/// Every method takes `&self`: state lives in sharded `DashMap`/`DashSet`s and the
/// id counter is atomic, so a single registry can be shared between tasks.
pub struct ConnectionRegistry {
    /// Source of fresh [`ConnId`]s; starts at 1.
    next_id: AtomicU64,
    /// Live connections keyed by id.
    conns: DashMap<ConnId, ConnEntry>,
    /// Room name -> ids of the connections in that room. A room's entry is removed
    /// once its member set becomes empty.
    rooms: DashMap<String, DashSet<ConnId>>,
}
impl ConnectionRegistry {
    /// Creates an empty registry; the first [`ConnId`] handed out is `1`.
    pub fn new() -> Self {
        Self {
            next_id: AtomicU64::new(1),
            conns: DashMap::new(),
            rooms: DashMap::new(),
        }
    }

    /// Registers a connection whose outbound frames are delivered through `tx`.
    ///
    /// The connection starts in no rooms with its idle timer set to "now", and the
    /// `ws_connections` gauge is incremented. Returns the newly allocated id.
    pub fn register(&self, tx: mpsc::Sender<WsMessage>, claims: Option<Arc<Claims>>) -> ConnId {
        // Relaxed suffices: only the uniqueness of the id matters.
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        self.conns.insert(
            id,
            ConnEntry {
                tx,
                member_rooms: DashSet::new(),
                last_seen: AtomicU64::new(unix_now()),
                claims,
            },
        );
        metrics::gauge!("ws_connections").increment(1.0);
        id
    }

    /// Records activity on `id` by resetting its idle timer. Unknown ids are ignored.
    #[inline]
    pub fn touch(&self, id: ConnId) {
        if let Some(entry) = self.conns.get(&id) {
            entry.last_seen.store(unix_now(), Ordering::Relaxed);
        }
    }

    /// Tries to enqueue `msg` on `entry`'s outbound queue without waiting.
    ///
    /// Returns `false` only when the queue is full: the caller should then evict the
    /// connection as a slow consumer. A closed queue returns `true` (no eviction from
    /// here). NOTE(review): presumably the task owning the receiver unregisters
    /// itself, and otherwise `sweep_idle` reaps it — confirm.
    fn enqueue_or_evict(&self, id: ConnId, entry: &ConnEntry, msg: WsMessage) -> bool {
        match entry.tx.try_send(msg) {
            Ok(()) => {
                metrics::counter!("ws_messages_out_total").increment(1);
                true
            }
            Err(mpsc::error::TrySendError::Full(_)) => {
                metrics::counter!("ws_slow_client_evictions_total").increment(1);
                tracing::warn!(conn = id, "WS outbound queue full — evicting slow client");
                false
            }
            Err(mpsc::error::TrySendError::Closed(_)) => true,
        }
    }

    /// Enqueues a `Close` frame for every registered connection.
    ///
    /// Best-effort: a frame that does not fit (full or closed queue) is dropped. The
    /// return value counts connections visited, not frames actually queued.
    pub fn close_all(&self) -> usize {
        let mut n = 0;
        for entry in self.conns.iter() {
            let _ = entry.value().tx.try_send(WsMessage::Close);
            n += 1;
        }
        n
    }

    /// Closes and unregisters every connection idle for more than `max_idle_secs`.
    ///
    /// A `Close` frame is enqueued best-effort before removal. Returns the reaped ids.
    pub fn sweep_idle(&self, max_idle_secs: u64) -> Vec<ConnId> {
        let now = unix_now();
        // Collect first: `unregister` needs shard write locks that the iterator's
        // read guards would block.
        let stale: Vec<ConnId> = self
            .conns
            .iter()
            .filter(|e| {
                now.saturating_sub(e.value().last_seen.load(Ordering::Relaxed)) > max_idle_secs
            })
            .map(|e| *e.key())
            .collect();
        for id in &stale {
            metrics::counter!("ws_idle_reaped_total").increment(1);
            if let Some(entry) = self.conns.get(id) {
                let _ = entry.tx.try_send(WsMessage::Close);
            }
            self.unregister(*id);
        }
        stale
    }

    /// Removes `id` from the connection table and from every room it had joined.
    ///
    /// Rooms left without members are dropped, and the `ws_connections` gauge is
    /// decremented. An unknown (already removed) id is a no-op, so calling this
    /// twice is harmless.
    pub fn unregister(&self, id: ConnId) {
        let Some((_, entry)) = self.conns.remove(&id) else {
            return;
        };
        metrics::gauge!("ws_connections").decrement(1.0);
        for room in entry.member_rooms.iter() {
            self.detach_from_room(id, room.key());
        }
    }

    /// Removes `id` from `room`'s member set and drops the set if that emptied it.
    fn detach_from_room(&self, id: ConnId, room: &str) {
        let emptied = match self.rooms.get(room) {
            Some(members) => {
                members.remove(&id);
                members.is_empty()
            }
            None => return,
        };
        // The read guard is released by now; `remove_if` takes the shard write lock
        // and re-checks emptiness, because another connection may have joined since.
        if emptied {
            self.rooms.remove_if(room, |_, members| members.is_empty());
        }
    }

    /// Number of currently registered connections.
    #[inline]
    pub fn connection_count(&self) -> usize {
        self.conns.len()
    }

    /// Adds connection `id` to `room`, creating the room if needed. Unknown ids are
    /// ignored.
    pub fn join_room(&self, id: ConnId, room: String) {
        // The `conns` guard is held across both inserts. A concurrent `unregister`
        // needs the shard write lock to remove the entry, so it cannot run in between
        // and miss the new membership.
        if let Some(entry) = self.conns.get(&id) {
            entry.member_rooms.insert(room.clone());
            self.rooms.entry(room).or_default().insert(id);
        }
    }

    /// Removes connection `id` from `room`; the room is dropped once empty.
    /// No-op if `id` is not registered.
    pub fn leave_room(&self, id: ConnId, room: &str) {
        if let Some(entry) = self.conns.get(&id) {
            entry.member_rooms.remove(room);
            self.detach_from_room(id, room);
        }
    }

    /// Enqueues `text` for connection `id`, evicting it if its queue is full.
    /// Unknown ids are ignored.
    #[inline]
    pub fn send_text(&self, id: ConnId, text: Arc<str>) {
        let evict = match self.conns.get(&id) {
            Some(entry) => !self.enqueue_or_evict(id, &entry, WsMessage::Text(text)),
            None => false,
        };
        // The `conns` guard is already dropped; `unregister` needs the write lock.
        if evict {
            self.unregister(id);
        }
    }

    /// Sends `text` to every registered connection, evicting those whose queues are
    /// full.
    pub fn broadcast_text(&self, text: &str) {
        // One allocation shared by every recipient's queue.
        let arc: Arc<str> = Arc::from(text);
        let mut evict: Vec<ConnId> = Vec::new();
        for entry in self.conns.iter() {
            if !self.enqueue_or_evict(
                *entry.key(),
                entry.value(),
                WsMessage::Text(Arc::clone(&arc)),
            ) {
                evict.push(*entry.key());
            }
        }
        // Unregister only once the iterator, and its shard read guards, are gone.
        for id in evict {
            self.unregister(id);
        }
    }

    /// Sends `text` to every member of `room`, evicting members whose queues are
    /// full. Unknown rooms are ignored.
    pub fn broadcast_room_text(&self, room: &str, text: &str) {
        // Snapshot the member ids and release the `rooms` guard before touching
        // `conns`. Holding it while locking `conns` would invert the
        // `conns` -> `rooms` lock order that `join_room`/`leave_room` use, which can
        // deadlock under a fair or writer-preferring shard lock.
        let Some(members) = self.rooms.get(room) else {
            return;
        };
        let ids: Vec<ConnId> = members.iter().map(|member| *member).collect();
        drop(members);

        let arc: Arc<str> = Arc::from(text);
        for id in ids {
            // Same path as a direct send: enqueue, and evict on a full queue.
            self.send_text(id, Arc::clone(&arc));
        }
    }
}
impl Default for ConnectionRegistry {
fn default() -> Self {
Self::new()
}
}
fn envelope<T: Serialize>(event: &str, payload: &T) -> String {
#[derive(Serialize)]
struct Envelope<'a, P> {
event: &'a str,
data: &'a P,
}
serde_json::to_string(&Envelope {
event,
data: payload,
})
.unwrap_or_else(|_| String::from(r#"{"event":"error","data":null}"#))
}
/// Per-connection handle for sending frames: to its own connection, to everyone, or
/// to a room. It also manages the connection's room memberships. Cloning copies an
/// id, a reference and two `Option<Arc<_>>`s.
#[derive(Clone)]
pub struct WsClient {
    /// Connection this handle addresses.
    id: ConnId,
    /// Registry the connection is registered in. The `'static` borrow suggests a
    /// process-wide registry — NOTE(review): confirm how it is allocated.
    reg: &'static ConnectionRegistry,
    /// Claims attached at construction; exposed via `claims()`.
    claims: Option<Arc<Claims>>,
    /// Tenant configuration attached at construction; exposed via `tenant()`.
    tenant: Option<Arc<crate::web::tenant::TenantConfig>>,
}
impl WsClient {
#[doc(hidden)]
pub fn __new(
id: ConnId,
reg: &'static ConnectionRegistry,
claims: Option<Arc<Claims>>,
tenant: Option<Arc<crate::web::tenant::TenantConfig>>,
) -> Self {
Self {
id,
reg,
claims,
tenant,
}
}
#[inline]
pub fn id(&self) -> ConnId {
self.id
}
#[inline]
pub fn claims(&self) -> Option<&Claims> {
self.claims.as_deref()
}
#[inline]
pub fn tenant(&self) -> Option<&crate::web::tenant::TenantConfig> {
self.tenant.as_deref()
}
pub async fn emit<T: Serialize>(&self, event: &str, payload: T) {
let text: Arc<str> = envelope(event, &payload).into();
self.reg.send_text(self.id, text);
}
pub async fn broadcast<T: Serialize>(&self, event: &str, payload: T) {
self.reg.broadcast_text(&envelope(event, &payload));
}
pub async fn broadcast_to_room<T: Serialize>(&self, room: &str, event: &str, payload: T) {
self.reg
.broadcast_room_text(room, &envelope(event, &payload));
}
pub fn join_room(&self, room: impl Into<String>) {
self.reg.join_room(self.id, room.into());
}
pub fn leave_room(&self, room: &str) {
self.reg.leave_room(self.id, room);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Test-only hook: backdate a connection's idle timer so `sweep_idle` can be
    // exercised without sleeping. The enclosing module is already `cfg(test)`.
    impl ConnectionRegistry {
        fn set_last_seen(&self, id: ConnId, unix_secs: u64) {
            if let Some(entry) = self.conns.get(&id) {
                entry.last_seen.store(unix_secs, Ordering::Relaxed);
            }
        }
    }

    /// Registry with one connection whose outbound queue holds `buffer` frames.
    fn reg_with_conn(buffer: usize) -> (ConnectionRegistry, ConnId, mpsc::Receiver<WsMessage>) {
        let reg = ConnectionRegistry::new();
        let (tx, rx) = mpsc::channel(buffer);
        let id = reg.register(tx, None);
        (reg, id, rx)
    }

    #[tokio::test]
    async fn slow_client_is_evicted_when_queue_fills() {
        let (reg, _id, _rx) = reg_with_conn(1);
        assert_eq!(reg.connection_count(), 1);
        reg.broadcast_text("one");
        assert_eq!(reg.connection_count(), 1);
        reg.broadcast_text("two");
        assert_eq!(reg.connection_count(), 0, "slow client must be evicted");
    }

    #[tokio::test]
    async fn fast_client_receives_broadcasts() {
        let (reg, _id, mut rx) = reg_with_conn(8);
        reg.broadcast_text("hello");
        match rx.recv().await {
            Some(WsMessage::Text(t)) => assert_eq!(&*t, "hello"),
            other => panic!("expected text frame, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn room_membership_and_cleanup() {
        let (reg, id, mut rx) = reg_with_conn(8);
        reg.join_room(id, "alpha".into());
        reg.broadcast_room_text("alpha", "in-room");
        assert!(matches!(rx.recv().await, Some(WsMessage::Text(t)) if &*t == "in-room"));
        reg.leave_room(id, "alpha");
        assert!(reg.rooms.is_empty(), "leaving the last member must drop the room");
        reg.broadcast_room_text("alpha", "after-leave");
        assert!(
            rx.try_recv().is_err(),
            "must not receive after leaving the room"
        );
        reg.join_room(id, "beta".into());
        reg.unregister(id);
        assert_eq!(reg.connection_count(), 0);
        assert!(reg.rooms.is_empty(), "unregister must drop rooms it emptied");
        // Broadcasting to a room that no longer exists must be a silent no-op.
        reg.broadcast_room_text("beta", "to-nobody");
    }

    #[tokio::test]
    async fn slow_room_member_is_evicted() {
        let (reg, id, _rx) = reg_with_conn(1);
        reg.join_room(id, "gamma".into());
        reg.broadcast_room_text("gamma", "one");
        assert_eq!(reg.connection_count(), 1);
        reg.broadcast_room_text("gamma", "two");
        assert_eq!(reg.connection_count(), 0, "slow room member must be evicted");
        assert!(reg.rooms.is_empty(), "eviction must also clean up the room");
    }

    #[tokio::test]
    async fn sweep_idle_reaps_only_stale_connections() {
        let (reg, stale, _rx1) = reg_with_conn(8);
        let (tx2, mut rx2) = mpsc::channel(8);
        let fresh = reg.register(tx2, None);
        reg.set_last_seen(stale, unix_now().saturating_sub(3600));
        let reaped = reg.sweep_idle(60);
        assert_eq!(reaped, vec![stale]);
        assert_eq!(reg.connection_count(), 1);
        reg.send_text(fresh, Arc::from("still-alive"));
        assert!(matches!(rx2.recv().await, Some(WsMessage::Text(t)) if &*t == "still-alive"));
    }

    #[tokio::test]
    async fn touch_resets_idle_timer() {
        let (reg, id, _rx) = reg_with_conn(8);
        reg.set_last_seen(id, unix_now().saturating_sub(3600));
        reg.touch(id);
        assert!(reg.sweep_idle(60).is_empty(), "touched connection must not be reaped");
        assert_eq!(reg.connection_count(), 1);
    }

    #[tokio::test]
    async fn close_all_enqueues_close_frames() {
        let (reg, _id, mut rx) = reg_with_conn(8);
        assert_eq!(reg.close_all(), 1);
        assert!(matches!(rx.recv().await, Some(WsMessage::Close)));
    }
}