armature_websocket/
room.rs1use crate::connection::{Connection, ConnectionId};
4use crate::error::{WebSocketError, WebSocketResult};
5use crate::message::Message;
6use dashmap::DashMap;
7use std::collections::HashSet;
8use std::sync::Arc;
9
/// Identifier under which a [`Room`] is stored in the [`RoomManager`].
pub type RoomId = std::string::String;
12
13#[derive(Debug)]
15pub struct Room {
16 pub id: RoomId,
18 members: DashMap<ConnectionId, ()>,
20}
21
22impl Room {
23 pub fn new(id: RoomId) -> Self {
25 Self {
26 id,
27 members: DashMap::new(),
28 }
29 }
30
31 pub fn join(&self, connection_id: ConnectionId) {
33 self.members.insert(connection_id, ());
34 }
35
36 pub fn leave(&self, connection_id: &str) -> bool {
38 self.members.remove(connection_id).is_some()
39 }
40
41 pub fn contains(&self, connection_id: &str) -> bool {
43 self.members.contains_key(connection_id)
44 }
45
46 pub fn len(&self) -> usize {
48 self.members.len()
49 }
50
51 pub fn is_empty(&self) -> bool {
53 self.members.is_empty()
54 }
55
56 pub fn members(&self) -> Vec<ConnectionId> {
58 self.members.iter().map(|r| r.key().clone()).collect()
59 }
60}
61
62pub struct RoomManager {
64 rooms: DashMap<RoomId, Arc<Room>>,
66 connection_rooms: DashMap<ConnectionId, HashSet<RoomId>>,
68 connections: DashMap<ConnectionId, Connection>,
70}
71
72impl RoomManager {
73 pub fn new() -> Self {
75 Self {
76 rooms: DashMap::new(),
77 connection_rooms: DashMap::new(),
78 connections: DashMap::new(),
79 }
80 }
81
82 pub fn register_connection(&self, connection: Connection) {
84 let id = connection.id.clone();
85 self.connections.insert(id.clone(), connection);
86 self.connection_rooms.insert(id, HashSet::new());
87 }
88
89 pub fn unregister_connection(&self, connection_id: &str) {
91 if let Some((_, room_ids)) = self.connection_rooms.remove(connection_id) {
92 for room_id in room_ids {
93 if let Some(room) = self.rooms.get(&room_id) {
94 room.leave(connection_id);
95 }
96 self.rooms.remove_if(&room_id, |_, room| room.is_empty());
98 }
99 }
100 self.connections.remove(connection_id);
101 }
102
103 pub fn get_connection(&self, connection_id: &str) -> Option<Connection> {
105 self.connections.get(connection_id).map(|c| c.clone())
106 }
107
108 pub fn create_room(&self, room_id: RoomId) -> Arc<Room> {
110 self.rooms
111 .entry(room_id.clone())
112 .or_insert_with(|| Arc::new(Room::new(room_id)))
113 .clone()
114 }
115
116 pub fn get_room(&self, room_id: &str) -> Option<Arc<Room>> {
118 self.rooms.get(room_id).map(|r| r.clone())
119 }
120
121 pub fn delete_room(&self, room_id: &str) -> bool {
123 if let Some((_, room)) = self.rooms.remove(room_id) {
124 for member_id in room.members() {
126 if let Some(mut rooms) = self.connection_rooms.get_mut(&member_id) {
127 rooms.remove(room_id);
128 }
129 }
130 true
131 } else {
132 false
133 }
134 }
135
136 pub fn join_room(&self, connection_id: &str, room_id: &str) -> WebSocketResult<()> {
138 if !self.connections.contains_key(connection_id) {
139 return Err(WebSocketError::ConnectionNotFound(connection_id.to_string()));
140 }
141
142 let room = self.create_room(room_id.to_string());
143 room.join(connection_id.to_string());
144
145 if let Some(mut rooms) = self.connection_rooms.get_mut(connection_id) {
146 rooms.insert(room_id.to_string());
147 }
148
149 Ok(())
150 }
151
152 pub fn leave_room(&self, connection_id: &str, room_id: &str) -> WebSocketResult<()> {
154 if let Some(room) = self.rooms.get(room_id) {
155 room.leave(connection_id);
156 }
157
158 if let Some(mut rooms) = self.connection_rooms.get_mut(connection_id) {
159 rooms.remove(room_id);
160 }
161
162 self.rooms.remove_if(room_id, |_, room| room.is_empty());
164
165 Ok(())
166 }
167
168 pub fn broadcast_to_room(&self, room_id: &str, message: Message) -> WebSocketResult<usize> {
170 let room = self
171 .rooms
172 .get(room_id)
173 .ok_or_else(|| WebSocketError::RoomNotFound(room_id.to_string()))?;
174
175 let mut sent_count = 0;
176 for member_id in room.members() {
177 if let Some(conn) = self.connections.get(&member_id) {
178 if conn.send(message.clone()).is_ok() {
179 sent_count += 1;
180 }
181 }
182 }
183
184 Ok(sent_count)
185 }
186
187 pub fn broadcast_to_room_except(
189 &self,
190 room_id: &str,
191 message: Message,
192 except_id: &str,
193 ) -> WebSocketResult<usize> {
194 let room = self
195 .rooms
196 .get(room_id)
197 .ok_or_else(|| WebSocketError::RoomNotFound(room_id.to_string()))?;
198
199 let mut sent_count = 0;
200 for member_id in room.members() {
201 if member_id != except_id {
202 if let Some(conn) = self.connections.get(&member_id) {
203 if conn.send(message.clone()).is_ok() {
204 sent_count += 1;
205 }
206 }
207 }
208 }
209
210 Ok(sent_count)
211 }
212
213 pub fn broadcast_all(&self, message: Message) -> usize {
215 let mut sent_count = 0;
216 for conn in self.connections.iter() {
217 if conn.send(message.clone()).is_ok() {
218 sent_count += 1;
219 }
220 }
221 sent_count
222 }
223
224 pub fn room_ids(&self) -> Vec<RoomId> {
226 self.rooms.iter().map(|r| r.key().clone()).collect()
227 }
228
229 pub fn connection_ids(&self) -> Vec<ConnectionId> {
231 self.connections.iter().map(|c| c.key().clone()).collect()
232 }
233
234 pub fn connection_count(&self) -> usize {
236 self.connections.len()
237 }
238
239 pub fn room_count(&self) -> usize {
241 self.rooms.len()
242 }
243}
244
245impl Default for RoomManager {
246 fn default() -> Self {
247 Self::new()
248 }
249}