use crate::connection::{Connection, ConnectionId};
use crate::error::{WebSocketError, WebSocketResult};
use crate::message::Message;
use dashmap::DashMap;
use std::collections::HashSet;
use std::sync::Arc;
/// Identifier of a room; rooms are stored and looked up under this string key.
pub type RoomId = String;
/// A named group of connections, identified by `id`.
///
/// Membership is changed through `&self` methods, so a `Room` can be shared
/// (for example behind an `Arc`) and updated without exclusive access.
#[derive(Debug)]
pub struct Room {
/// Identifier this room was created with.
pub id: RoomId,
/// Ids of the connections currently in the room. The unit value means the
/// map is used purely as a set of keys.
members: DashMap<ConnectionId, ()>,
}
impl Room {
pub fn new(id: RoomId) -> Self {
Self {
id,
members: DashMap::new(),
}
}
pub fn join(&self, connection_id: ConnectionId) {
self.members.insert(connection_id, ());
}
pub fn leave(&self, connection_id: &str) -> bool {
self.members.remove(connection_id).is_some()
}
pub fn contains(&self, connection_id: &str) -> bool {
self.members.contains_key(connection_id)
}
pub fn len(&self) -> usize {
self.members.len()
}
pub fn is_empty(&self) -> bool {
self.members.is_empty()
}
pub fn members(&self) -> Vec<ConnectionId> {
self.members.iter().map(|r| r.key().clone()).collect()
}
}
/// Registry of connections and rooms that tracks membership in both
/// directions: which connections are in a room, and which rooms a connection
/// has joined.
pub struct RoomManager {
/// Rooms keyed by id. Each room is stored behind an `Arc` so handles can be
/// given out to callers.
rooms: DashMap<RoomId, Arc<Room>>,
/// For each registered connection, the ids of the rooms it has joined.
connection_rooms: DashMap<ConnectionId, HashSet<RoomId>>,
/// Registered connections keyed by their id.
connections: DashMap<ConnectionId, Connection>,
}
impl RoomManager {
pub fn new() -> Self {
Self {
rooms: DashMap::new(),
connection_rooms: DashMap::new(),
connections: DashMap::new(),
}
}
pub fn register_connection(&self, connection: Connection) {
let id = connection.id.clone();
self.connections.insert(id.clone(), connection);
self.connection_rooms.insert(id, HashSet::new());
}
pub fn unregister_connection(&self, connection_id: &str) {
if let Some((_, room_ids)) = self.connection_rooms.remove(connection_id) {
for room_id in room_ids {
if let Some(room) = self.rooms.get(&room_id) {
room.leave(connection_id);
}
self.rooms.remove_if(&room_id, |_, room| room.is_empty());
}
}
self.connections.remove(connection_id);
}
pub fn get_connection(&self, connection_id: &str) -> Option<Connection> {
self.connections.get(connection_id).map(|c| c.clone())
}
pub fn create_room(&self, room_id: RoomId) -> Arc<Room> {
self.rooms
.entry(room_id.clone())
.or_insert_with(|| Arc::new(Room::new(room_id)))
.clone()
}
pub fn get_room(&self, room_id: &str) -> Option<Arc<Room>> {
self.rooms.get(room_id).map(|r| r.clone())
}
pub fn delete_room(&self, room_id: &str) -> bool {
if let Some((_, room)) = self.rooms.remove(room_id) {
for member_id in room.members() {
if let Some(mut rooms) = self.connection_rooms.get_mut(&member_id) {
rooms.remove(room_id);
}
}
true
} else {
false
}
}
pub fn join_room(&self, connection_id: &str, room_id: &str) -> WebSocketResult<()> {
if !self.connections.contains_key(connection_id) {
return Err(WebSocketError::ConnectionNotFound(connection_id.to_string()));
}
let room = self.create_room(room_id.to_string());
room.join(connection_id.to_string());
if let Some(mut rooms) = self.connection_rooms.get_mut(connection_id) {
rooms.insert(room_id.to_string());
}
Ok(())
}
pub fn leave_room(&self, connection_id: &str, room_id: &str) -> WebSocketResult<()> {
if let Some(room) = self.rooms.get(room_id) {
room.leave(connection_id);
}
if let Some(mut rooms) = self.connection_rooms.get_mut(connection_id) {
rooms.remove(room_id);
}
self.rooms.remove_if(room_id, |_, room| room.is_empty());
Ok(())
}
pub fn broadcast_to_room(&self, room_id: &str, message: Message) -> WebSocketResult<usize> {
let room = self
.rooms
.get(room_id)
.ok_or_else(|| WebSocketError::RoomNotFound(room_id.to_string()))?;
let mut sent_count = 0;
for member_id in room.members() {
if let Some(conn) = self.connections.get(&member_id) {
if conn.send(message.clone()).is_ok() {
sent_count += 1;
}
}
}
Ok(sent_count)
}
pub fn broadcast_to_room_except(
&self,
room_id: &str,
message: Message,
except_id: &str,
) -> WebSocketResult<usize> {
let room = self
.rooms
.get(room_id)
.ok_or_else(|| WebSocketError::RoomNotFound(room_id.to_string()))?;
let mut sent_count = 0;
for member_id in room.members() {
if member_id != except_id {
if let Some(conn) = self.connections.get(&member_id) {
if conn.send(message.clone()).is_ok() {
sent_count += 1;
}
}
}
}
Ok(sent_count)
}
pub fn broadcast_all(&self, message: Message) -> usize {
let mut sent_count = 0;
for conn in self.connections.iter() {
if conn.send(message.clone()).is_ok() {
sent_count += 1;
}
}
sent_count
}
pub fn room_ids(&self) -> Vec<RoomId> {
self.rooms.iter().map(|r| r.key().clone()).collect()
}
pub fn connection_ids(&self) -> Vec<ConnectionId> {
self.connections.iter().map(|c| c.key().clone()).collect()
}
pub fn connection_count(&self) -> usize {
self.connections.len()
}
pub fn room_count(&self) -> usize {
self.rooms.len()
}
}
impl Default for RoomManager {
fn default() -> Self {
Self::new()
}
}