use std::sync::atomic::{AtomicUsize, Ordering};
use dashmap::DashMap;
use tokio::sync::{mpsc, oneshot};
/// Identifier of a single connection; used as the key of every per-connection map in
/// [`ConnectionManager`].
pub type ConnectionId = String;
/// Message delivered over a connection's one-shot close channel to ask it to shut down.
#[derive(Debug, Clone)]
pub struct CloseSignal {
    /// Numeric close code.
    // NOTE(review): presumably a WebSocket-style close code (the manager sends 4002) — confirm.
    pub code: u16,
    /// Human-readable explanation of why the connection is being closed.
    pub reason: String,
}
/// Metadata recorded for one registered connection.
#[derive(Debug, Clone)]
pub struct ConnectionState {
    /// Identifier under which the connection is registered.
    pub connection_id: ConnectionId,
    /// Identifier of the user associated with the connection.
    pub user_id: String,
    /// Hash of the connection's context; connections can be counted per hash value.
    pub context_hash: u64,
    /// Expiry of the connection.
    // NOTE(review): presumably a Unix timestamp; units are not visible here — confirm.
    pub expires_at: i64,
}
impl ConnectionState {
    /// Builds a [`ConnectionState`] from its four fields.
    ///
    /// Every argument is stored unchanged in the field of the same name; no validation
    /// is performed.
    #[must_use]
    pub const fn new(
        connection_id: ConnectionId,
        user_id: String,
        context_hash: u64,
        expires_at: i64,
    ) -> Self {
        Self {
            connection_id,
            user_id,
            context_hash,
            expires_at,
        }
    }
}
/// Registry of connections together with the channels and counters kept for each one.
///
/// All four maps are keyed by [`ConnectionId`].
pub struct ConnectionManager {
    /// State record of each registered connection.
    connections: DashMap<ConnectionId, ConnectionState>,
    /// Sending half of each connection's bounded event channel (events are `String`s).
    event_senders: DashMap<ConnectionId, mpsc::Sender<String>>,
    /// One-shot sender used to deliver a [`CloseSignal`] to the connection.
    close_senders: DashMap<ConnectionId, oneshot::Sender<CloseSignal>>,
    /// Number of consecutive failed event sends per connection.
    drop_counts: DashMap<ConnectionId, AtomicUsize>,
    /// Consecutive failed sends at which a connection is sent a close signal.
    max_consecutive_drops: usize,
    /// Buffer capacity used when creating each connection's event channel.
    connection_event_capacity: usize,
}
impl ConnectionManager {
    /// Creates an empty manager.
    ///
    /// * `max_consecutive_drops` — number of back-to-back failed
    ///   [`send_event`](Self::send_event) calls at which the connection is sent a
    ///   "slow consumer" [`CloseSignal`].
    /// * `connection_event_capacity` — buffer size of each connection's event channel.
    ///   A value of `0` is treated as `1` when a channel is created, because
    ///   `tokio::sync::mpsc::channel` panics on a zero capacity.
    #[must_use]
    pub fn new(max_consecutive_drops: usize, connection_event_capacity: usize) -> Self {
        Self {
            connections: DashMap::new(),
            event_senders: DashMap::new(),
            close_senders: DashMap::new(),
            drop_counts: DashMap::new(),
            max_consecutive_drops,
            connection_event_capacity,
        }
    }

    /// Registers `state` and creates its event and close channels.
    ///
    /// Returns the receiving halves: the event receiver yields the JSON strings passed
    /// to [`send_event`](Self::send_event), and the close receiver yields the
    /// [`CloseSignal`] issued once the drop threshold is reached. The connection's
    /// drop counter starts at zero.
    ///
    /// Inserting an ID that is already registered overwrites its previous entries.
    #[must_use]
    pub fn insert(
        &self,
        state: ConnectionState,
    ) -> (mpsc::Receiver<String>, oneshot::Receiver<CloseSignal>) {
        // `mpsc::channel` panics when given a capacity of 0; clamp so that a manager
        // configured with zero capacity still works instead of panicking on insert.
        let capacity = self.connection_event_capacity.max(1);
        let (event_tx, event_rx) = mpsc::channel(capacity);
        let (close_tx, close_rx) = oneshot::channel();

        self.event_senders.insert(state.connection_id.clone(), event_tx);
        self.close_senders.insert(state.connection_id.clone(), close_tx);
        self.drop_counts
            .insert(state.connection_id.clone(), AtomicUsize::new(0));
        // The state record goes in last, so `count` only sees the connection once its
        // channels and counter are already in place.
        self.connections.insert(state.connection_id.clone(), state);

        (event_rx, close_rx)
    }

    /// Removes every entry kept for `connection_id`.
    ///
    /// Removing an ID that is not registered is a no-op.
    pub fn remove(&self, connection_id: &str) {
        self.connections.remove(connection_id);
        self.event_senders.remove(connection_id);
        self.close_senders.remove(connection_id);
        self.drop_counts.remove(connection_id);
    }

    /// Returns the number of registered connections.
    #[must_use]
    pub fn count(&self) -> usize {
        self.connections.len()
    }

    /// Returns how many registered connections have the given `context_hash`.
    ///
    /// This scans every registered connection.
    #[must_use]
    pub fn count_by_context(&self, context_hash: u64) -> usize {
        self.connections
            .iter()
            .filter(|entry| entry.value().context_hash == context_hash)
            .count()
    }

    /// Tries to queue `json` on the connection's event channel without waiting.
    ///
    /// Returns `true` if the event was queued. Returns `false` if the connection has
    /// no event sender or if `try_send` fails for any reason.
    ///
    /// A successful send resets the connection's drop counter to zero. A failed send
    /// increments it, and once it reaches `max_consecutive_drops` the connection is
    /// sent a [`CloseSignal`] with code `4002` and reason `"slow consumer"`. The close
    /// sender is consumed by that, so the signal is issued at most once.
    #[must_use]
    pub fn send_event(&self, connection_id: &str, json: String) -> bool {
        let sent = self
            .event_senders
            .get(connection_id)
            .is_some_and(|sender| sender.try_send(json).is_ok());

        // Connections without a counter are not tracked; nothing to update.
        let Some(counter) = self.drop_counts.get(connection_id) else {
            return sent;
        };

        // `Relaxed` suffices: the counter is an independent tally and does not
        // publish or guard any other memory.
        if sent {
            counter.store(0, Ordering::Relaxed);
            return true;
        }

        // `fetch_add` returns the previous value, hence the `+ 1`.
        let consecutive_drops = counter.fetch_add(1, Ordering::Relaxed) + 1;
        if consecutive_drops >= self.max_consecutive_drops {
            if let Some((_, close_tx)) = self.close_senders.remove(connection_id) {
                // Best effort: the close receiver may already have been dropped.
                let _ = close_tx.send(CloseSignal {
                    code: 4002,
                    reason: "slow consumer".to_owned(),
                });
            }
        }
        false
    }

    /// Returns the current consecutive-drop count for `connection_id`, or `0` if the
    /// connection has no counter.
    #[must_use]
    pub fn drop_count(&self, connection_id: &str) -> usize {
        self.drop_counts
            .get(connection_id)
            .map_or(0, |counter| counter.load(Ordering::Relaxed))
    }
}
impl Default for ConnectionManager {
    /// Creates a manager that tolerates 50 consecutive dropped events per connection
    /// and is configured with an event-channel capacity of 256.
    fn default() -> Self {
        const MAX_CONSECUTIVE_DROPS: usize = 50;
        const EVENT_CAPACITY: usize = 256;

        Self::new(MAX_CONSECUTIVE_DROPS, EVENT_CAPACITY)
    }
}