use super::connection::Connection;
use super::message::Message;
use crate::error::{Result, TidewayError};
use dashmap::DashMap;
use serde::Serialize;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
/// Shared, asynchronously lockable handle to a single [`Connection`].
///
/// The `Arc` lets the manager and other holders keep the same connection alive;
/// the tokio `RwLock` guards access to the connection's state.
pub type ConnectionHandle = Arc<tokio::sync::RwLock<Connection>>;
/// Registry of live connections plus room and user indexes and a few counters.
///
/// All mutable state sits behind `Arc`s, so cloning the manager produces another
/// handle onto the same maps and counters rather than an independent copy.
#[derive(Clone)]
pub struct ConnectionManager {
/// Connection ID -> handle for every registered connection.
connections: Arc<DashMap<String, ConnectionHandle>>,
/// Room name -> IDs of the connections that were added to that room.
rooms: Arc<DashMap<String, HashSet<String>>>,
/// User ID -> IDs of the connections registered for that user.
users: Arc<DashMap<String, HashSet<String>>>,
/// Upper bound on simultaneously registered connections; `0` disables the check.
max_connections: usize,
/// Number of currently registered connections (bumped on register, lowered on unregister).
active_count: Arc<AtomicU64>,
/// Number of successful registrations since the manager was created.
total_connections: Arc<AtomicU64>,
/// Number of calls to `broadcast` since the manager was created.
total_broadcasts: Arc<AtomicU64>,
}
impl ConnectionManager {
/// Creates a manager that does not enforce a connection limit.
pub fn new() -> Self {
    // A limit of zero is treated as "unlimited" by `register`.
    const NO_LIMIT: usize = 0;
    Self::with_max_connections(NO_LIMIT)
}
/// Creates an empty manager that accepts at most `max_connections` simultaneous
/// connections. Passing `0` disables the limit.
pub fn with_max_connections(max_connections: usize) -> Self {
    // Every counter starts at zero and lives behind its own `Arc`.
    let zeroed_counter = || Arc::new(AtomicU64::new(0));

    let connections = Arc::new(DashMap::new());
    let rooms = Arc::new(DashMap::new());
    let users = Arc::new(DashMap::new());

    Self {
        connections,
        rooms,
        users,
        max_connections,
        active_count: zeroed_counter(),
        total_connections: zeroed_counter(),
        total_broadcasts: zeroed_counter(),
    }
}
/// Registers `conn` under its own ID and, when it reports a user ID, indexes it
/// under that user as well.
///
/// If a connection with the same ID is already registered, its entry is replaced
/// and the active-connection counter is left unchanged, so the counter keeps
/// matching the number of entries in the map.
///
/// # Errors
///
/// Returns a `service_unavailable` error when a non-zero connection limit is
/// configured and the active-connection counter has already reached it.
pub async fn register(&self, conn: ConnectionHandle) -> Result<()> {
    // Copy out what is needed and release the read lock before touching the maps.
    let (conn_id, user_id) = {
        let conn_guard = conn.read().await;
        (
            conn_guard.id().to_string(),
            conn_guard.user_id().map(String::from),
        )
    };

    if self.max_connections > 0 {
        let max = self.max_connections as u64;
        // Reserve a slot atomically: the counter is only bumped while it is still
        // below the limit, so concurrent registrations cannot overshoot it.
        let reserved = self.active_count.fetch_update(
            Ordering::AcqRel,
            Ordering::Acquire,
            |current| (current < max).then(|| current + 1),
        );
        if reserved.is_err() {
            return Err(TidewayError::service_unavailable(format!(
                "Maximum connection limit ({}) reached",
                self.max_connections
            )));
        }
    } else {
        self.active_count.fetch_add(1, Ordering::Relaxed);
    }

    if self.connections.insert(conn_id.clone(), conn).is_some() {
        // An entry with this ID was replaced, so the map did not grow: give the
        // slot reserved above back (saturating, so the counter can never wrap).
        let _ = self.active_count.fetch_update(
            Ordering::AcqRel,
            Ordering::Acquire,
            |count| count.checked_sub(1),
        );
    }

    if let Some(uid) = user_id {
        self.users.entry(uid).or_default().insert(conn_id);
    }

    self.total_connections.fetch_add(1, Ordering::Relaxed);
    Ok(())
}
/// Removes the connection `conn_id` from the manager, from its user's index and
/// from every room. Does nothing when the ID is not registered.
pub async fn unregister(&self, conn_id: &str) {
    let Some((_, conn)) = self.connections.remove(conn_id) else {
        return;
    };

    // Saturating decrement: if the counter was lowered in the meantime (for
    // example by `reconcile_counter`), a plain `fetch_sub` would wrap around to
    // `u64::MAX` and make every later limit check fail.
    let _ = self.active_count.fetch_update(
        Ordering::Relaxed,
        Ordering::Relaxed,
        |count| count.checked_sub(1),
    );

    let conn_guard = conn.read().await;
    if let Some(user_id) = conn_guard.user_id() {
        if let Some(mut user_conns) = self.users.get_mut(user_id) {
            user_conns.remove(conn_id);
        }
        // Re-check emptiness under the shard lock so that a connection added for
        // this user between the two steps is not thrown away with the entry.
        self.users.remove_if(user_id, |_, conns| conns.is_empty());
    }
    drop(conn_guard);

    // The full scan covers every room, including ones missing from the
    // connection's own room list, so no separate per-room pass is needed.
    self.remove_connection_from_all_rooms(conn_id);
}
/// Removes `conn_id` from every room and deletes rooms that end up empty.
fn remove_connection_from_all_rooms(&self, conn_id: &str) {
    // `retain` edits and prunes each entry while it holds that entry's shard
    // lock, so a member added concurrently cannot be lost between the emptiness
    // check and the removal, and no snapshot of room names has to be cloned.
    self.rooms.retain(|_, members| {
        members.remove(conn_id);
        !members.is_empty()
    });
}
pub fn get(&self, conn_id: &str) -> Option<ConnectionHandle> {
self.connections.get(conn_id).map(|entry| entry.clone())
}
/// Sends `msg` to every registered connection, one after another.
///
/// Connections whose send fails are unregistered once the pass is complete.
///
/// # Errors
///
/// Returns an internal error stating how many connections could not be reached.
pub async fn broadcast(&self, msg: Message) -> Result<()> {
    self.total_broadcasts.fetch_add(1, Ordering::Relaxed);

    // Snapshot the targets first so no map guard is held across an `.await`.
    let targets: Vec<(String, ConnectionHandle)> = self
        .connections
        .iter()
        .map(|entry| (entry.key().clone(), entry.value().clone()))
        .collect();

    let mut failed_conns = Vec::new();
    for (conn_id, conn) in targets {
        // Clone the sender and release the read lock before awaiting the send.
        let sender = {
            let conn_guard = conn.read().await;
            conn_guard.sender_clone()
        };
        if sender.send(msg.clone()).await.is_err() {
            failed_conns.push(conn_id);
        }
    }

    for conn_id in &failed_conns {
        self.unregister(conn_id).await;
    }

    if failed_conns.is_empty() {
        Ok(())
    } else {
        Err(TidewayError::internal(format!(
            "Failed to send to {} connections",
            failed_conns.len()
        )))
    }
}
/// Wraps `text` in a `Message::Text` and broadcasts it to every connection.
pub async fn broadcast_text(&self, text: impl Into<String>) -> Result<()> {
    let payload: String = text.into();
    let message = Message::Text(payload);
    self.broadcast(message).await
}
/// Serializes `data` to a JSON string and broadcasts it as text to every connection.
///
/// # Errors
///
/// Returns an internal error when serialization fails; otherwise forwards the
/// result of `broadcast_text`.
pub async fn broadcast_json<T: Serialize>(&self, data: &T) -> Result<()> {
    let encoded = match serde_json::to_string(data) {
        Ok(encoded) => encoded,
        Err(e) => {
            return Err(TidewayError::internal(format!(
                "Failed to serialize JSON: {}",
                e
            )));
        }
    };
    self.broadcast_text(encoded).await
}
/// Sends `msg` to every connection currently listed in `room`.
///
/// Succeeds without doing anything when the room is unknown or empty. Listed IDs
/// that are no longer registered are skipped. Connections whose send fails are
/// unregistered once the pass is complete.
///
/// # Errors
///
/// Returns an internal error stating how many connections in the room could not
/// be reached.
pub async fn broadcast_to_room(&self, room: &str, msg: Message) -> Result<()> {
    // Copy the member IDs out so the room's map guard is released immediately.
    let conn_ids: Vec<String> = self
        .rooms
        .get(room)
        .map(|entry| entry.value().iter().cloned().collect())
        .unwrap_or_default();
    if conn_ids.is_empty() {
        return Ok(());
    }

    let mut failed_conns = Vec::new();
    for conn_id in conn_ids {
        let Some(conn) = self
            .connections
            .get(&conn_id)
            .map(|entry| entry.value().clone())
        else {
            continue;
        };
        // Clone the sender and release the read lock before awaiting the send.
        let sender = {
            let conn_guard = conn.read().await;
            conn_guard.sender_clone()
        };
        if sender.send(msg.clone()).await.is_err() {
            failed_conns.push(conn_id);
        }
    }

    for conn_id in &failed_conns {
        self.unregister(conn_id).await;
    }

    if failed_conns.is_empty() {
        Ok(())
    } else {
        Err(TidewayError::internal(format!(
            "Failed to send to {} connections in room {}",
            failed_conns.len(),
            room
        )))
    }
}
/// Wraps `text` in a `Message::Text` and sends it to every connection in `room`.
pub async fn broadcast_text_to_room(&self, room: &str, text: impl Into<String>) -> Result<()> {
    let message = Message::Text(text.into());
    self.broadcast_to_room(room, message).await
}
/// Serializes `data` to a JSON string and sends it as text to every connection in `room`.
///
/// # Errors
///
/// Returns an internal error when serialization fails; otherwise forwards the
/// result of `broadcast_text_to_room`.
pub async fn broadcast_json_to_room<T: Serialize>(&self, room: &str, data: &T) -> Result<()> {
    let encoded = match serde_json::to_string(data) {
        Ok(encoded) => encoded,
        Err(e) => {
            return Err(TidewayError::internal(format!(
                "Failed to serialize JSON: {}",
                e
            )));
        }
    };
    self.broadcast_text_to_room(room, encoded).await
}
/// Sends `msg` to every connection indexed under `user_id`.
///
/// Succeeds without doing anything when the user has no indexed connections.
/// Indexed IDs that are no longer registered are skipped. Connections whose send
/// fails are unregistered once the pass is complete.
///
/// # Errors
///
/// Returns an internal error stating how many of the user's connections could
/// not be reached.
pub async fn broadcast_to_user(&self, user_id: &str, msg: Message) -> Result<()> {
    // Copy the connection IDs out so the user's map guard is released immediately.
    let conn_ids: Vec<String> = self
        .users
        .get(user_id)
        .map(|entry| entry.value().iter().cloned().collect())
        .unwrap_or_default();
    if conn_ids.is_empty() {
        return Ok(());
    }

    let mut failed_conns = Vec::new();
    for conn_id in conn_ids {
        let Some(conn) = self
            .connections
            .get(&conn_id)
            .map(|entry| entry.value().clone())
        else {
            continue;
        };
        // Clone the sender and release the read lock before awaiting the send.
        let sender = {
            let conn_guard = conn.read().await;
            conn_guard.sender_clone()
        };
        if sender.send(msg.clone()).await.is_err() {
            failed_conns.push(conn_id);
        }
    }

    for conn_id in &failed_conns {
        self.unregister(conn_id).await;
    }

    if failed_conns.is_empty() {
        Ok(())
    } else {
        Err(TidewayError::internal(format!(
            "Failed to send to {} connections for user {}",
            failed_conns.len(),
            user_id
        )))
    }
}
/// Wraps `text` in a `Message::Text` and sends it to every connection of `user_id`.
pub async fn broadcast_text_to_user(
    &self,
    user_id: &str,
    text: impl Into<String>,
) -> Result<()> {
    let message = Message::Text(text.into());
    self.broadcast_to_user(user_id, message).await
}
/// Serializes `data` to a JSON string and sends it as text to every connection of `user_id`.
///
/// # Errors
///
/// Returns an internal error when serialization fails; otherwise forwards the
/// result of `broadcast_text_to_user`.
pub async fn broadcast_json_to_user<T: Serialize>(
    &self,
    user_id: &str,
    data: &T,
) -> Result<()> {
    let encoded = match serde_json::to_string(data) {
        Ok(encoded) => encoded,
        Err(e) => {
            return Err(TidewayError::internal(format!(
                "Failed to serialize JSON: {}",
                e
            )));
        }
    };
    self.broadcast_text_to_user(user_id, encoded).await
}
pub fn add_to_room(&self, conn_id: &str, room: &str) {
self.rooms
.entry(room.to_string())
.or_default()
.insert(conn_id.to_string());
if let Some(conn) = self.connections.get(conn_id) {
if let Ok(mut conn_guard) = conn.try_write() {
conn_guard.join_room(room);
}
}
}
pub fn remove_from_room(&self, conn_id: &str, room: &str) {
if let Some(mut room_conns) = self.rooms.get_mut(room) {
room_conns.remove(conn_id);
if room_conns.is_empty() {
drop(room_conns);
self.rooms.remove(room);
}
}
if let Some(conn) = self.connections.get(conn_id) {
if let Ok(mut conn_guard) = conn.try_write() {
conn_guard.leave_room(room);
}
}
}
/// Returns the IDs of the connections currently listed in `room`, or an empty
/// vector when the room does not exist.
pub fn room_members(&self, room: &str) -> Vec<String> {
    match self.rooms.get(room) {
        Some(members) => members.value().iter().cloned().collect(),
        None => Vec::new(),
    }
}
/// Returns the number of rooms that currently have an entry in the room index.
pub fn room_count(&self) -> usize {
    let rooms = &self.rooms;
    rooms.len()
}
/// Returns the current value of the active-connection counter.
pub fn connection_count(&self) -> usize {
    let active = self.active_count.load(Ordering::Relaxed);
    active as usize
}
pub fn max_connections(&self) -> usize {
self.max_connections
}
/// Returns how many registrations have succeeded since the manager was created.
pub fn total_connections(&self) -> u64 {
    let counter = &self.total_connections;
    counter.load(Ordering::Relaxed)
}
/// Returns the current value of the broadcast counter.
pub fn total_broadcasts(&self) -> u64 {
    let counter = &self.total_broadcasts;
    counter.load(Ordering::Relaxed)
}
/// Collects the manager's counters and sizes into a [`ConnectionMetrics`] value.
///
/// Each figure is read separately, so the values are not taken as one atomic snapshot.
pub fn metrics(&self) -> ConnectionMetrics {
    let active_connections = self.connection_count();
    let max_connections = self.max_connections;
    let total_connections = self.total_connections();
    let total_broadcasts = self.total_broadcasts();
    let room_count = self.room_count();

    ConnectionMetrics {
        active_connections,
        max_connections,
        total_connections,
        total_broadcasts,
        room_count,
    }
}
/// Indexes `conn_id` under the user ID its connection currently reports.
///
/// Does nothing when the connection is unknown, when its lock cannot be acquired
/// for reading right now, or when it has no user ID.
pub fn update_user_mapping(&self, conn_id: &str) {
    let Some(conn) = self.connections.get(conn_id) else {
        return;
    };
    // Non-blocking attempt only: this method is synchronous and cannot await the lock.
    let Ok(conn_guard) = conn.try_read() else {
        return;
    };
    if let Some(user_id) = conn_guard.user_id() {
        let owner = user_id.to_string();
        let member = conn_id.to_string();
        self.users.entry(owner).or_default().insert(member);
    }
}
/// Compares the active-connection counter with the number of entries in the
/// connection map and overwrites the counter when the two differ.
///
/// Returns `Some((previous_counter, map_len))` when a correction was made and
/// `None` when the values already matched.
pub fn reconcile_counter(&self) -> Option<(u64, u64)> {
    let in_map = self.connections.len() as u64;
    let recorded = self.active_count.load(Ordering::Acquire);

    if in_map == recorded {
        return None;
    }

    tracing::warn!(
        actual = in_map,
        atomic = recorded,
        drift = (recorded as i64 - in_map as i64),
        "Connection counter drift detected, correcting"
    );
    self.active_count.store(in_map, Ordering::Release);
    Some((recorded, in_map))
}
}
/// Point-in-time figures reported by `ConnectionManager::metrics`.
#[derive(Debug, Clone)]
pub struct ConnectionMetrics {
/// Value of the active-connection counter when the metrics were read.
pub active_connections: usize,
/// Configured connection limit; `0` means no limit is enforced.
pub max_connections: usize,
/// Number of successful registrations since the manager was created.
pub total_connections: u64,
/// Value of the broadcast counter when the metrics were read.
pub total_broadcasts: u64,
/// Number of rooms that had an entry when the metrics were read.
pub room_count: usize,
}
impl Default for ConnectionManager {
fn default() -> Self {
Self::new()
}
}