armature_websocket/
room.rs

1//! Room-based message broadcasting.
2
3use crate::connection::{Connection, ConnectionId};
4use crate::error::{WebSocketError, WebSocketResult};
5use crate::message::Message;
6use dashmap::DashMap;
7use std::collections::HashSet;
8use std::sync::Arc;
9
/// Identifier naming a [`Room`]; an owned string such as `"lobby"`.
pub type RoomId = std::string::String;
12
13/// A room for grouping WebSocket connections.
14#[derive(Debug)]
15pub struct Room {
16    /// Room identifier
17    pub id: RoomId,
18    /// Connection IDs in this room
19    members: DashMap<ConnectionId, ()>,
20}
21
22impl Room {
23    /// Create a new room.
24    pub fn new(id: RoomId) -> Self {
25        Self {
26            id,
27            members: DashMap::new(),
28        }
29    }
30
31    /// Add a connection to the room.
32    pub fn join(&self, connection_id: ConnectionId) {
33        self.members.insert(connection_id, ());
34    }
35
36    /// Remove a connection from the room.
37    pub fn leave(&self, connection_id: &str) -> bool {
38        self.members.remove(connection_id).is_some()
39    }
40
41    /// Check if a connection is in the room.
42    pub fn contains(&self, connection_id: &str) -> bool {
43        self.members.contains_key(connection_id)
44    }
45
46    /// Get the number of connections in the room.
47    pub fn len(&self) -> usize {
48        self.members.len()
49    }
50
51    /// Check if the room is empty.
52    pub fn is_empty(&self) -> bool {
53        self.members.is_empty()
54    }
55
56    /// Get all connection IDs in the room.
57    pub fn members(&self) -> Vec<ConnectionId> {
58        self.members.iter().map(|r| r.key().clone()).collect()
59    }
60}
61
62/// Manages rooms and their members.
63pub struct RoomManager {
64    /// All rooms
65    rooms: DashMap<RoomId, Arc<Room>>,
66    /// Mapping of connection ID to room IDs
67    connection_rooms: DashMap<ConnectionId, HashSet<RoomId>>,
68    /// All connections
69    connections: DashMap<ConnectionId, Connection>,
70}
71
72impl RoomManager {
73    /// Create a new room manager.
74    pub fn new() -> Self {
75        Self {
76            rooms: DashMap::new(),
77            connection_rooms: DashMap::new(),
78            connections: DashMap::new(),
79        }
80    }
81
82    /// Register a connection.
83    pub fn register_connection(&self, connection: Connection) {
84        let id = connection.id.clone();
85        self.connections.insert(id.clone(), connection);
86        self.connection_rooms.insert(id, HashSet::new());
87    }
88
89    /// Unregister a connection and remove it from all rooms.
90    pub fn unregister_connection(&self, connection_id: &str) {
91        if let Some((_, room_ids)) = self.connection_rooms.remove(connection_id) {
92            for room_id in room_ids {
93                if let Some(room) = self.rooms.get(&room_id) {
94                    room.leave(connection_id);
95                }
96                // Atomically remove room if empty (avoids TOCTOU race)
97                self.rooms.remove_if(&room_id, |_, room| room.is_empty());
98            }
99        }
100        self.connections.remove(connection_id);
101    }
102
103    /// Get a connection by ID.
104    pub fn get_connection(&self, connection_id: &str) -> Option<Connection> {
105        self.connections.get(connection_id).map(|c| c.clone())
106    }
107
108    /// Create a room if it doesn't exist.
109    pub fn create_room(&self, room_id: RoomId) -> Arc<Room> {
110        self.rooms
111            .entry(room_id.clone())
112            .or_insert_with(|| Arc::new(Room::new(room_id)))
113            .clone()
114    }
115
116    /// Get a room by ID.
117    pub fn get_room(&self, room_id: &str) -> Option<Arc<Room>> {
118        self.rooms.get(room_id).map(|r| r.clone())
119    }
120
121    /// Delete a room.
122    pub fn delete_room(&self, room_id: &str) -> bool {
123        if let Some((_, room)) = self.rooms.remove(room_id) {
124            // Remove room from all connection's room sets
125            for member_id in room.members() {
126                if let Some(mut rooms) = self.connection_rooms.get_mut(&member_id) {
127                    rooms.remove(room_id);
128                }
129            }
130            true
131        } else {
132            false
133        }
134    }
135
136    /// Join a connection to a room.
137    pub fn join_room(&self, connection_id: &str, room_id: &str) -> WebSocketResult<()> {
138        if !self.connections.contains_key(connection_id) {
139            return Err(WebSocketError::ConnectionNotFound(connection_id.to_string()));
140        }
141
142        let room = self.create_room(room_id.to_string());
143        room.join(connection_id.to_string());
144
145        if let Some(mut rooms) = self.connection_rooms.get_mut(connection_id) {
146            rooms.insert(room_id.to_string());
147        }
148
149        Ok(())
150    }
151
152    /// Remove a connection from a room.
153    pub fn leave_room(&self, connection_id: &str, room_id: &str) -> WebSocketResult<()> {
154        if let Some(room) = self.rooms.get(room_id) {
155            room.leave(connection_id);
156        }
157
158        if let Some(mut rooms) = self.connection_rooms.get_mut(connection_id) {
159            rooms.remove(room_id);
160        }
161
162        // Atomically remove room if empty (avoids TOCTOU race)
163        self.rooms.remove_if(room_id, |_, room| room.is_empty());
164
165        Ok(())
166    }
167
168    /// Broadcast a message to all connections in a room.
169    pub fn broadcast_to_room(&self, room_id: &str, message: Message) -> WebSocketResult<usize> {
170        let room = self
171            .rooms
172            .get(room_id)
173            .ok_or_else(|| WebSocketError::RoomNotFound(room_id.to_string()))?;
174
175        let mut sent_count = 0;
176        for member_id in room.members() {
177            if let Some(conn) = self.connections.get(&member_id) {
178                if conn.send(message.clone()).is_ok() {
179                    sent_count += 1;
180                }
181            }
182        }
183
184        Ok(sent_count)
185    }
186
187    /// Broadcast a message to all connections in a room except one.
188    pub fn broadcast_to_room_except(
189        &self,
190        room_id: &str,
191        message: Message,
192        except_id: &str,
193    ) -> WebSocketResult<usize> {
194        let room = self
195            .rooms
196            .get(room_id)
197            .ok_or_else(|| WebSocketError::RoomNotFound(room_id.to_string()))?;
198
199        let mut sent_count = 0;
200        for member_id in room.members() {
201            if member_id != except_id {
202                if let Some(conn) = self.connections.get(&member_id) {
203                    if conn.send(message.clone()).is_ok() {
204                        sent_count += 1;
205                    }
206                }
207            }
208        }
209
210        Ok(sent_count)
211    }
212
213    /// Broadcast a message to all connections.
214    pub fn broadcast_all(&self, message: Message) -> usize {
215        let mut sent_count = 0;
216        for conn in self.connections.iter() {
217            if conn.send(message.clone()).is_ok() {
218                sent_count += 1;
219            }
220        }
221        sent_count
222    }
223
224    /// Get all room IDs.
225    pub fn room_ids(&self) -> Vec<RoomId> {
226        self.rooms.iter().map(|r| r.key().clone()).collect()
227    }
228
229    /// Get all connection IDs.
230    pub fn connection_ids(&self) -> Vec<ConnectionId> {
231        self.connections.iter().map(|c| c.key().clone()).collect()
232    }
233
234    /// Get the total number of connections.
235    pub fn connection_count(&self) -> usize {
236        self.connections.len()
237    }
238
239    /// Get the total number of rooms.
240    pub fn room_count(&self) -> usize {
241        self.rooms.len()
242    }
243}
244
245impl Default for RoomManager {
246    fn default() -> Self {
247        Self::new()
248    }
249}