use crate::error::{Result, TidewayError};
use super::connection::Connection;
use super::message::Message;
use serde::Serialize;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use dashmap::DashMap;
/// Shared, lockable handle to a single [`Connection`].
///
/// Cloning a handle is cheap (an `Arc` bump); every clone refers to the same connection.
pub type ConnectionHandle = Arc<tokio::sync::RwLock<Connection>>;

/// Sets of connection ids keyed by a grouping name (a room name or a user id).
type IdIndex = DashMap<String, HashSet<String>>;

/// Registry of live connections, plus room and user indexes over their ids.
///
/// Cloning the manager is cheap: every clone shares the same maps and counters.
#[derive(Clone)]
pub struct ConnectionManager {
    /// Live connections keyed by connection id.
    connections: Arc<DashMap<String, ConnectionHandle>>,
    /// Room name -> ids of the connections that joined that room.
    rooms: Arc<IdIndex>,
    /// User id -> ids of the connections belonging to that user.
    users: Arc<IdIndex>,
    /// Connection cap; `0` disables the limit.
    max_connections: usize,
    /// Running count of successful registrations.
    total_connections: Arc<AtomicU64>,
    /// Running count of global broadcasts.
    total_broadcasts: Arc<AtomicU64>,
}
impl ConnectionManager {
pub fn new() -> Self {
Self::with_max_connections(0)
}
pub fn with_max_connections(max_connections: usize) -> Self {
Self {
connections: Arc::new(DashMap::new()),
rooms: Arc::new(DashMap::new()),
users: Arc::new(DashMap::new()),
max_connections,
total_connections: Arc::new(AtomicU64::new(0)),
total_broadcasts: Arc::new(AtomicU64::new(0)),
}
}
pub async fn register(&self, conn: ConnectionHandle) -> Result<()> {
if self.max_connections > 0 && self.connections.len() >= self.max_connections {
return Err(TidewayError::service_unavailable(format!(
"Maximum connection limit ({}) reached",
self.max_connections
)));
}
let conn_id = {
let conn_guard = conn.read().await;
let id = conn_guard.id().to_string();
if let Some(user_id) = conn_guard.user_id() {
self.users
.entry(user_id.to_string())
.or_insert_with(HashSet::new)
.insert(id.clone());
}
id
};
self.connections.insert(conn_id.clone(), conn);
self.total_connections.fetch_add(1, Ordering::Relaxed);
Ok(())
}
pub async fn unregister(&self, conn_id: &str) {
if let Some((_, conn)) = self.connections.remove(conn_id) {
let conn_guard = conn.read().await;
if let Some(user_id) = conn_guard.user_id() {
if let Some(mut user_conns) = self.users.get_mut(user_id) {
user_conns.remove(conn_id);
if user_conns.is_empty() {
drop(user_conns);
self.users.remove(user_id);
}
}
}
for room_name in conn_guard.rooms() {
if let Some(mut room_conns) = self.rooms.get_mut(room_name) {
room_conns.remove(conn_id);
if room_conns.is_empty() {
drop(room_conns);
self.rooms.remove(room_name);
}
}
}
}
}
pub fn get(&self, conn_id: &str) -> Option<ConnectionHandle> {
self.connections.get(conn_id).map(|entry| entry.clone())
}
pub async fn broadcast(&self, msg: Message) -> Result<()> {
self.total_broadcasts.fetch_add(1, Ordering::Relaxed);
let mut errors = Vec::new();
let mut failed_conns = Vec::new();
let conns: Vec<(String, ConnectionHandle)> = self.connections
.iter()
.map(|entry| (entry.key().clone(), entry.value().clone()))
.collect();
for (conn_id, conn) in conns {
if let Ok(conn_guard) = conn.try_read() {
if let Err(e) = conn_guard.send(msg.clone()).await {
errors.push(e);
failed_conns.push(conn_id);
}
} else {
failed_conns.push(conn_id);
}
}
for conn_id in failed_conns {
let _ = self.unregister(&conn_id).await;
}
if !errors.is_empty() {
Err(TidewayError::internal(format!(
"Failed to send to {} connections",
errors.len()
)))
} else {
Ok(())
}
}
pub async fn broadcast_text(&self, text: impl Into<String>) -> Result<()> {
self.broadcast(Message::Text(text.into())).await
}
pub async fn broadcast_json<T: Serialize>(&self, data: &T) -> Result<()> {
let json = serde_json::to_string(data)
.map_err(|e| TidewayError::internal(format!("Failed to serialize JSON: {}", e)))?;
self.broadcast_text(json).await
}
pub async fn broadcast_to_room(&self, room: &str, msg: Message) -> Result<()> {
let conn_ids: Vec<String> = self
.rooms
.get(room)
.map(|entry| entry.value().iter().cloned().collect())
.unwrap_or_default();
if conn_ids.is_empty() {
return Ok(());
}
let mut errors = Vec::new();
let mut failed_conns = Vec::new();
for conn_id in conn_ids {
if let Some(conn) = self.connections.get(&conn_id) {
if let Ok(conn_guard) = conn.try_read() {
if let Err(e) = conn_guard.send(msg.clone()).await {
errors.push(e);
failed_conns.push(conn_id);
}
} else {
failed_conns.push(conn_id);
}
}
}
for conn_id in failed_conns {
let _ = self.unregister(&conn_id).await;
}
if !errors.is_empty() {
Err(TidewayError::internal(format!(
"Failed to send to {} connections in room {}",
errors.len(),
room
)))
} else {
Ok(())
}
}
pub async fn broadcast_text_to_room(&self, room: &str, text: impl Into<String>) -> Result<()> {
self.broadcast_to_room(room, Message::Text(text.into())).await
}
pub async fn broadcast_json_to_room<T: Serialize>(&self, room: &str, data: &T) -> Result<()> {
let json = serde_json::to_string(data)
.map_err(|e| TidewayError::internal(format!("Failed to serialize JSON: {}", e)))?;
self.broadcast_text_to_room(room, json).await
}
pub async fn broadcast_to_user(&self, user_id: &str, msg: Message) -> Result<()> {
let conn_ids: Vec<String> = self
.users
.get(user_id)
.map(|entry| entry.value().iter().cloned().collect())
.unwrap_or_default();
if conn_ids.is_empty() {
return Ok(());
}
let mut errors = Vec::new();
let mut failed_conns = Vec::new();
for conn_id in conn_ids {
if let Some(conn) = self.connections.get(&conn_id) {
if let Ok(conn_guard) = conn.try_read() {
if let Err(e) = conn_guard.send(msg.clone()).await {
errors.push(e);
failed_conns.push(conn_id);
}
} else {
failed_conns.push(conn_id);
}
}
}
for conn_id in failed_conns {
let _ = self.unregister(&conn_id).await;
}
if !errors.is_empty() {
Err(TidewayError::internal(format!(
"Failed to send to {} connections for user {}",
errors.len(),
user_id
)))
} else {
Ok(())
}
}
pub async fn broadcast_text_to_user(&self, user_id: &str, text: impl Into<String>) -> Result<()> {
self.broadcast_to_user(user_id, Message::Text(text.into())).await
}
pub async fn broadcast_json_to_user<T: Serialize>(&self, user_id: &str, data: &T) -> Result<()> {
let json = serde_json::to_string(data)
.map_err(|e| TidewayError::internal(format!("Failed to serialize JSON: {}", e)))?;
self.broadcast_text_to_user(user_id, json).await
}
pub fn add_to_room(&self, conn_id: &str, room: &str) {
self.rooms
.entry(room.to_string())
.or_insert_with(HashSet::new)
.insert(conn_id.to_string());
if let Some(conn) = self.connections.get(conn_id) {
if let Ok(mut conn_guard) = conn.try_write() {
conn_guard.join_room(room);
}
}
}
pub fn remove_from_room(&self, conn_id: &str, room: &str) {
if let Some(mut room_conns) = self.rooms.get_mut(room) {
room_conns.remove(conn_id);
if room_conns.is_empty() {
drop(room_conns);
self.rooms.remove(room);
}
}
if let Some(conn) = self.connections.get(conn_id) {
if let Ok(mut conn_guard) = conn.try_write() {
conn_guard.leave_room(room);
}
}
}
pub fn room_members(&self, room: &str) -> Vec<String> {
self.rooms
.get(room)
.map(|entry| entry.value().iter().cloned().collect())
.unwrap_or_default()
}
pub fn room_count(&self) -> usize {
self.rooms.len()
}
pub fn connection_count(&self) -> usize {
self.connections.len()
}
pub fn max_connections(&self) -> usize {
self.max_connections
}
pub fn total_connections(&self) -> u64 {
self.total_connections.load(Ordering::Relaxed)
}
pub fn total_broadcasts(&self) -> u64 {
self.total_broadcasts.load(Ordering::Relaxed)
}
pub fn metrics(&self) -> ConnectionMetrics {
ConnectionMetrics {
active_connections: self.connection_count(),
max_connections: self.max_connections,
total_connections: self.total_connections(),
total_broadcasts: self.total_broadcasts(),
room_count: self.room_count(),
}
}
pub fn update_user_mapping(&self, conn_id: &str) {
if let Some(conn) = self.connections.get(conn_id) {
if let Ok(conn_guard) = conn.try_read() {
if let Some(user_id) = conn_guard.user_id() {
self.users
.entry(user_id.to_string())
.or_insert_with(HashSet::new)
.insert(conn_id.to_string());
}
}
}
}
}
/// Point-in-time snapshot of a `ConnectionManager`'s counters.
///
/// All fields are plain values, so snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionMetrics {
    /// Connections registered when the snapshot was taken.
    pub active_connections: usize,
    /// Configured connection cap; `0` means unlimited.
    pub max_connections: usize,
    /// Successful registrations since the manager was created.
    pub total_connections: u64,
    /// Global broadcasts issued since the manager was created.
    pub total_broadcasts: u64,
    /// Number of rooms in the manager's room index.
    pub room_count: usize,
}
impl Default for ConnectionManager {
fn default() -> Self {
Self::new()
}
}