1use std::{
2 collections::HashMap,
3 error::Error,
4 fmt,
5 sync::atomic::{AtomicU32, Ordering},
6};
7
8use tokio::{
9 sync::{broadcast, RwLock},
10 task::JoinHandle,
11};
12
13struct Room<K, U, T> {
18 name: K,
19 tx: broadcast::Sender<T>,
20 inner_user: RwLock<HashMap<U, UserTask>>,
21 user_count: AtomicU32,
22}
23
24struct UserTask {
26 task: JoinHandle<()>,
27}
28
29pub struct RoomsManager<K, U, T> {
42 inner: RwLock<HashMap<K, Room<K, U, T>>>,
43 user_reciever: RwLock<HashMap<U, broadcast::Sender<T>>>,
44}
45
/// Errors returned by `RoomsManager` operations.
///
/// The enum carries no data, so it is cheap to copy and can be compared
/// directly (e.g. `assert_eq!(err, RoomError::RoomNotFound)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RoomError {
    /// No room is registered under the requested name.
    RoomNotFound,
    /// Broadcasting to the room failed; tokio's broadcast `send` errors
    /// when the room currently has no subscribers.
    MessageSendFail,
    /// The user has no personal channel yet; `init_user` must be called first.
    NotInitiated,
}

impl Error for RoomError {}

impl fmt::Display for RoomError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            RoomError::RoomNotFound => "target room not found",
            RoomError::NotInitiated => "user is not initiated",
            RoomError::MessageSendFail => "failed to send message to the room",
        };
        f.write_str(message)
    }
}
73
74pub struct UserReceiverGuard<'a, K, U, T>
75where
76 T: Clone + Send + 'static,
77 K: Eq + std::hash::Hash + Clone,
78 U: Eq + std::hash::Hash + Clone,
79{
80 receiver: broadcast::Receiver<T>,
81 user: U,
82 manager: &'a RoomsManager<K, U, T>,
83}
84
85impl<K, U, T> Room<K, U, T>
86where
87 T: Clone + Send + 'static,
88 K: Eq + std::hash::Hash,
89 U: Eq + std::hash::Hash,
90{
91 fn new(name: K, capacity: Option<usize>) -> Room<K, U, T> {
94 let (tx, _rx) = broadcast::channel(capacity.unwrap_or(100));
95
96 Room {
97 name,
98 tx,
99 inner_user: RwLock::new(HashMap::new()),
100 user_count: AtomicU32::new(0),
101 }
102 }
103
104 async fn join(&self, user: U, user_sender: broadcast::Sender<T>) {
107 let mut inner = self.inner_user.write().await;
108
109 match inner.entry(user) {
110 std::collections::hash_map::Entry::Occupied(_) => {}
111 std::collections::hash_map::Entry::Vacant(data) => {
112 let mut room_rec = self.get_sender().subscribe();
113
114 let task = tokio::spawn(async move {
115 while let Ok(data) = room_rec.recv().await {
116 let _ = user_sender.send(data);
117 }
118 });
119
120 data.insert(UserTask { task });
121
122 self.user_count.fetch_add(1, Ordering::SeqCst);
123 }
124 }
125 }
126
127 async fn leave(&self, user: U) {
130 let mut inner = self.inner_user.write().await;
131
132 match inner.entry(user) {
133 std::collections::hash_map::Entry::Vacant(_) => {}
134 std::collections::hash_map::Entry::Occupied(data) => {
135 let data = data.remove();
136
137 data.task.abort();
138
139 self.user_count.fetch_sub(1, Ordering::SeqCst);
140 }
141 }
142 }
143
144 fn blocking_leave(&self, user: U) {
145 let mut inner = self.inner_user.blocking_write();
146
147 match inner.entry(user) {
148 std::collections::hash_map::Entry::Vacant(_) => {}
149 std::collections::hash_map::Entry::Occupied(data) => {
150 let data = data.remove();
151
152 data.task.abort();
153
154 self.user_count.fetch_sub(1, Ordering::SeqCst);
155 }
156 }
157 }
158
159 async fn clear_tasks(&self) {
160 let mut inner = self.inner_user.write().await;
161
162 inner.values().for_each(|value| {
163 value.task.abort();
164 });
165
166 inner.clear();
167
168 self.user_count.store(0, Ordering::SeqCst);
169 }
170
171 async fn contains_user(&self, user: &U) -> bool {
173 let inner = self.inner_user.read().await;
174
175 inner.contains_key(user)
176 }
177
178 fn is_empty(&self) -> bool {
180 self.user_count.load(Ordering::SeqCst) == 0
181 }
182
183 fn get_sender(&self) -> broadcast::Sender<T> {
185 self.tx.clone()
186 }
187
188 fn send(&self, data: T) -> Result<usize, broadcast::error::SendError<T>> {
190 self.tx.send(data)
191 }
192
193 async fn user_count(&self) -> u32 {
195 self.user_count.load(Ordering::SeqCst)
196 }
197}
198
199impl<K, U, T> RoomsManager<K, U, T>
200where
201 T: Clone + Send + 'static,
202 K: Eq + std::hash::Hash + Clone,
203 U: Eq + std::hash::Hash + Clone,
204{
205 pub fn new() -> Self {
206 RoomsManager {
207 inner: RwLock::new(HashMap::new()),
208 user_reciever: RwLock::new(HashMap::new()),
209 }
210 }
211
212 pub async fn new_room(&self, name: K, capacity: Option<usize>) {
213 let mut rooms = self.inner.write().await;
214
215 rooms.insert(name.clone(), Room::new(name, capacity));
216 }
217
218 pub async fn room_exists(&self, name: &K) -> bool {
219 let rooms = self.inner.read().await;
220
221 rooms.get(name).is_some()
222 }
223
224 pub async fn join_or_create(&self, user: U, room: K) -> Result<(), RoomError> {
225 match self.room_exists(&room).await {
226 true => self.join_room(room, user).await,
227 false => {
228 self.new_room(room.clone(), None).await;
229
230 self.join_room(room, user).await
231 }
232 }
233 }
234
235 pub async fn send_message_to_room(&self, name: &K, data: T) -> Result<usize, RoomError> {
239 let rooms = self.inner.read().await;
240
241 rooms
242 .get(name)
243 .ok_or(RoomError::RoomNotFound)?
244 .send(data)
245 .map_err(|_| RoomError::MessageSendFail)
246 }
247
248 pub async fn init_user(
250 &self,
251 user: U,
252 capacity: Option<usize>,
253 ) -> UserReceiverGuard<'_, K, U, T> {
254 let mut user_reciever = self.user_reciever.write().await;
255
256 match user_reciever.entry(user.clone()) {
257 std::collections::hash_map::Entry::Occupied(channel) => UserReceiverGuard {
258 user,
259 receiver: channel.get().subscribe(),
260 manager: self,
261 },
262 std::collections::hash_map::Entry::Vacant(v) => {
263 let (tx, rx) = broadcast::channel(capacity.unwrap_or(100));
264 v.insert(tx);
265
266 UserReceiverGuard {
267 user,
268 receiver: rx,
269 manager: self,
270 }
271 }
272 }
273 }
274
275 pub fn end_user(&self, user: U) {
277 let rooms = self.inner.blocking_write();
278 let mut user_reciever = self.user_reciever.blocking_write();
279
280 for (_key, room) in rooms.iter() {
281 room.blocking_leave(user.clone());
282 }
283
284 match user_reciever.entry(user.clone()) {
285 std::collections::hash_map::Entry::Occupied(o) => {
286 o.remove();
287 }
288 std::collections::hash_map::Entry::Vacant(_) => {}
289 }
290 }
291
292 pub async fn join_room(&self, name: K, user: U) -> Result<(), RoomError> {
294 let rooms = self.inner.read().await;
295 let user_reciever = self.user_reciever.read().await;
296
297 let user_reciever = user_reciever
298 .get(&user)
299 .ok_or(RoomError::NotInitiated)?
300 .clone();
301
302 rooms
303 .get(&name)
304 .ok_or(RoomError::RoomNotFound)?
305 .join(user.clone(), user_reciever)
306 .await;
307
308 Ok(())
309 }
310
311 pub async fn remove_room(&self, room: K) {
312 let mut rooms = self.inner.write().await;
313
314 match rooms.entry(room.clone()) {
315 std::collections::hash_map::Entry::Vacant(_) => {}
316 std::collections::hash_map::Entry::Occupied(el) => {
317 let room = el.remove();
318
319 room.clear_tasks().await;
320 }
321 }
322 }
323
324 pub async fn leave_room(&self, name: K, user: U) -> Result<(), RoomError> {
325 let rooms = self.inner.read().await;
326
327 rooms
328 .get(&name)
329 .ok_or(RoomError::RoomNotFound)?
330 .leave(user.clone())
331 .await;
332
333 Ok(())
334 }
335
336 pub async fn is_room_empty(&self, name: K) -> Result<bool, RoomError> {
337 let rooms = self.inner.read().await;
338
339 Ok(rooms.get(&name).ok_or(RoomError::RoomNotFound)?.is_empty())
340 }
341
342 pub async fn rooms_count(&self) -> usize {
343 let rooms = self.inner.read().await;
344
345 rooms.len()
346 }
347}
348
349impl<K, U, T> Default for RoomsManager<K, U, T>
350where
351 T: Clone + Send + 'static,
352 K: Eq + std::hash::Hash + Clone,
353 U: Eq + std::hash::Hash + Clone,
354{
355 fn default() -> Self {
356 Self::new()
357 }
358}
359
360impl<K, U, T> std::ops::Deref for UserReceiverGuard<'_, K, U, T>
361where
362 T: Clone + Send + 'static,
363 K: Eq + std::hash::Hash + Clone,
364 U: Eq + std::hash::Hash + Clone,
365{
366 type Target = broadcast::Receiver<T>;
367
368 fn deref(&self) -> &Self::Target {
369 &self.receiver
370 }
371}
372
373impl<K, U, T> std::ops::DerefMut for UserReceiverGuard<'_, K, U, T>
374where
375 T: Clone + Send + 'static,
376 K: Eq + std::hash::Hash + Clone,
377 U: Eq + std::hash::Hash + Clone,
378{
379 fn deref_mut(&mut self) -> &mut Self::Target {
380 &mut self.receiver
381 }
382}
383
384impl<K, U, T> Drop for UserReceiverGuard<'_, K, U, T>
385where
386 T: Clone + Send + 'static,
387 K: Eq + std::hash::Hash + Clone,
388 U: Eq + std::hash::Hash + Clone,
389{
390 fn drop(&mut self) {
391 self.manager.end_user(self.user.clone());
392 }
393}