Skip to main content

axum_ws_rooms/
lib.rs

1use std::{
2    collections::HashMap,
3    error::Error,
4    fmt,
5    sync::atomic::{AtomicU32, Ordering},
6};
7
8use tokio::{
9    sync::{broadcast, RwLock},
10    task::JoinHandle,
11};
12
13/// each room has a name and it contains `broadcast::sender<String>` which can be accessed
14/// by `get_sender` method and you can send message to a roome by calling `send` on room.
15/// each room counts how many user it has and there is a method to check if its empty
16/// each room track its joined users and stores spawned tasks handlers
17struct Room<K, U, T> {
18    name: K,
19    tx: broadcast::Sender<T>,
20    inner_user: RwLock<HashMap<U, UserTask>>,
21    user_count: AtomicU32,
22}
23
24/// struct that contains task handler that forwards messages
25struct UserTask {
26    task: JoinHandle<()>,
27}
28
29/// use in combination with `Arc` to share it between threads
30/// internally it uses `RwLock` so it can handle concurrent requests without a problem
31/// when a user connects to ws endpoint you have to call `init_user` and it gives you a guard that
32/// when dropped will remove user from all rooms
33/// # Generics
34/// `K` is type used to identify each room
35///
36/// `U` is type used to identify each user
37///
38/// `T` is message type that is sent between rooms and users
39/// # Examples
40/// examples are available in examples directory
41pub struct RoomsManager<K, U, T> {
42    inner: RwLock<HashMap<K, Room<K, U, T>>>,
43    user_reciever: RwLock<HashMap<U, broadcast::Sender<T>>>,
44}
45
/// Errors returned by `RoomsManager` operations.
#[derive(Debug)]
pub enum RoomError {
    /// The requested room does not exist.
    RoomNotFound,
    /// Broadcasting a message to the room failed (presumably because the room
    /// currently has no subscribers).
    MessageSendFail,
    /// The user was never registered with `init_user`.
    NotInitiated,
}

impl Error for RoomError {}

impl fmt::Display for RoomError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let message = match self {
            RoomError::RoomNotFound => "target room not found",
            RoomError::MessageSendFail => "failed to send message to the room",
            RoomError::NotInitiated => "user is not initiated",
        };

        f.write_str(message)
    }
}
73
74pub struct UserReceiverGuard<'a, K, U, T>
75where
76    T: Clone + Send + 'static,
77    K: Eq + std::hash::Hash + Clone,
78    U: Eq + std::hash::Hash + Clone,
79{
80    receiver: broadcast::Receiver<T>,
81    user: U,
82    manager: &'a RoomsManager<K, U, T>,
83}
84
85impl<K, U, T> Room<K, U, T>
86where
87    T: Clone + Send + 'static,
88    K: Eq + std::hash::Hash,
89    U: Eq + std::hash::Hash,
90{
91    /// creates new room with a given name
92    /// capacity is the underlying channel capacity and its default is 100
93    fn new(name: K, capacity: Option<usize>) -> Room<K, U, T> {
94        let (tx, _rx) = broadcast::channel(capacity.unwrap_or(100));
95
96        Room {
97            name,
98            tx,
99            inner_user: RwLock::new(HashMap::new()),
100            user_count: AtomicU32::new(0),
101        }
102    }
103
104    /// join the rooms with a unique user
105    /// if user has joined before, it does nothing
106    async fn join(&self, user: U, user_sender: broadcast::Sender<T>) {
107        let mut inner = self.inner_user.write().await;
108
109        match inner.entry(user) {
110            std::collections::hash_map::Entry::Occupied(_) => {}
111            std::collections::hash_map::Entry::Vacant(data) => {
112                let mut room_rec = self.get_sender().subscribe();
113
114                let task = tokio::spawn(async move {
115                    while let Ok(data) = room_rec.recv().await {
116                        let _ = user_sender.send(data);
117                    }
118                });
119
120                data.insert(UserTask { task });
121
122                self.user_count.fetch_add(1, Ordering::SeqCst);
123            }
124        }
125    }
126
127    /// leave the room with user
128    /// if user has left before it wont do anything
129    async fn leave(&self, user: U) {
130        let mut inner = self.inner_user.write().await;
131
132        match inner.entry(user) {
133            std::collections::hash_map::Entry::Vacant(_) => {}
134            std::collections::hash_map::Entry::Occupied(data) => {
135                let data = data.remove();
136
137                data.task.abort();
138
139                self.user_count.fetch_sub(1, Ordering::SeqCst);
140            }
141        }
142    }
143
144    fn blocking_leave(&self, user: U) {
145        let mut inner = self.inner_user.blocking_write();
146
147        match inner.entry(user) {
148            std::collections::hash_map::Entry::Vacant(_) => {}
149            std::collections::hash_map::Entry::Occupied(data) => {
150                let data = data.remove();
151
152                data.task.abort();
153
154                self.user_count.fetch_sub(1, Ordering::SeqCst);
155            }
156        }
157    }
158
159    async fn clear_tasks(&self) {
160        let mut inner = self.inner_user.write().await;
161
162        inner.values().for_each(|value| {
163            value.task.abort();
164        });
165
166        inner.clear();
167
168        self.user_count.store(0, Ordering::SeqCst);
169    }
170
171    /// check if user is in the room
172    async fn contains_user(&self, user: &U) -> bool {
173        let inner = self.inner_user.read().await;
174
175        inner.contains_key(user)
176    }
177
178    /// checks if room is empty
179    fn is_empty(&self) -> bool {
180        self.user_count.load(Ordering::SeqCst) == 0
181    }
182
183    /// get sender without joining room
184    fn get_sender(&self) -> broadcast::Sender<T> {
185        self.tx.clone()
186    }
187
188    ///send message to room
189    fn send(&self, data: T) -> Result<usize, broadcast::error::SendError<T>> {
190        self.tx.send(data)
191    }
192
193    /// get user count of room
194    async fn user_count(&self) -> u32 {
195        self.user_count.load(Ordering::SeqCst)
196    }
197}
198
199impl<K, U, T> RoomsManager<K, U, T>
200where
201    T: Clone + Send + 'static,
202    K: Eq + std::hash::Hash + Clone,
203    U: Eq + std::hash::Hash + Clone,
204{
205    pub fn new() -> Self {
206        RoomsManager {
207            inner: RwLock::new(HashMap::new()),
208            user_reciever: RwLock::new(HashMap::new()),
209        }
210    }
211
212    pub async fn new_room(&self, name: K, capacity: Option<usize>) {
213        let mut rooms = self.inner.write().await;
214
215        rooms.insert(name.clone(), Room::new(name, capacity));
216    }
217
218    pub async fn room_exists(&self, name: &K) -> bool {
219        let rooms = self.inner.read().await;
220
221        rooms.get(name).is_some()
222    }
223
224    pub async fn join_or_create(&self, user: U, room: K) -> Result<(), RoomError> {
225        match self.room_exists(&room).await {
226            true => self.join_room(room, user).await,
227            false => {
228                self.new_room(room.clone(), None).await;
229
230                self.join_room(room, user).await
231            }
232        }
233    }
234
235    /// send a message to a room
236    /// it will fail if there are no users in the room or
237    /// if room does not exists
238    pub async fn send_message_to_room(&self, name: &K, data: T) -> Result<usize, RoomError> {
239        let rooms = self.inner.read().await;
240
241        rooms
242            .get(name)
243            .ok_or(RoomError::RoomNotFound)?
244            .send(data)
245            .map_err(|_| RoomError::MessageSendFail)
246    }
247
248    /// call this at first of your code to initialize user notifier
249    pub async fn init_user(
250        &self,
251        user: U,
252        capacity: Option<usize>,
253    ) -> UserReceiverGuard<'_, K, U, T> {
254        let mut user_reciever = self.user_reciever.write().await;
255
256        match user_reciever.entry(user.clone()) {
257            std::collections::hash_map::Entry::Occupied(channel) => UserReceiverGuard {
258                user,
259                receiver: channel.get().subscribe(),
260                manager: self,
261            },
262            std::collections::hash_map::Entry::Vacant(v) => {
263                let (tx, rx) = broadcast::channel(capacity.unwrap_or(100));
264                v.insert(tx);
265
266                UserReceiverGuard {
267                    user,
268                    receiver: rx,
269                    manager: self,
270                }
271            }
272        }
273    }
274
275    /// call this at end of your code to remove user from all rooms
276    pub fn end_user(&self, user: U) {
277        let rooms = self.inner.blocking_write();
278        let mut user_reciever = self.user_reciever.blocking_write();
279
280        for (_key, room) in rooms.iter() {
281            room.blocking_leave(user.clone());
282        }
283
284        match user_reciever.entry(user.clone()) {
285            std::collections::hash_map::Entry::Occupied(o) => {
286                o.remove();
287            }
288            std::collections::hash_map::Entry::Vacant(_) => {}
289        }
290    }
291
292    /// join user to room
293    pub async fn join_room(&self, name: K, user: U) -> Result<(), RoomError> {
294        let rooms = self.inner.read().await;
295        let user_reciever = self.user_reciever.read().await;
296
297        let user_reciever = user_reciever
298            .get(&user)
299            .ok_or(RoomError::NotInitiated)?
300            .clone();
301
302        rooms
303            .get(&name)
304            .ok_or(RoomError::RoomNotFound)?
305            .join(user.clone(), user_reciever)
306            .await;
307
308        Ok(())
309    }
310
311    pub async fn remove_room(&self, room: K) {
312        let mut rooms = self.inner.write().await;
313
314        match rooms.entry(room.clone()) {
315            std::collections::hash_map::Entry::Vacant(_) => {}
316            std::collections::hash_map::Entry::Occupied(el) => {
317                let room = el.remove();
318
319                room.clear_tasks().await;
320            }
321        }
322    }
323
324    pub async fn leave_room(&self, name: K, user: U) -> Result<(), RoomError> {
325        let rooms = self.inner.read().await;
326
327        rooms
328            .get(&name)
329            .ok_or(RoomError::RoomNotFound)?
330            .leave(user.clone())
331            .await;
332
333        Ok(())
334    }
335
336    pub async fn is_room_empty(&self, name: K) -> Result<bool, RoomError> {
337        let rooms = self.inner.read().await;
338
339        Ok(rooms.get(&name).ok_or(RoomError::RoomNotFound)?.is_empty())
340    }
341
342    pub async fn rooms_count(&self) -> usize {
343        let rooms = self.inner.read().await;
344
345        rooms.len()
346    }
347}
348
349impl<K, U, T> Default for RoomsManager<K, U, T>
350where
351    T: Clone + Send + 'static,
352    K: Eq + std::hash::Hash + Clone,
353    U: Eq + std::hash::Hash + Clone,
354{
355    fn default() -> Self {
356        Self::new()
357    }
358}
359
360impl<K, U, T> std::ops::Deref for UserReceiverGuard<'_, K, U, T>
361where
362    T: Clone + Send + 'static,
363    K: Eq + std::hash::Hash + Clone,
364    U: Eq + std::hash::Hash + Clone,
365{
366    type Target = broadcast::Receiver<T>;
367
368    fn deref(&self) -> &Self::Target {
369        &self.receiver
370    }
371}
372
373impl<K, U, T> std::ops::DerefMut for UserReceiverGuard<'_, K, U, T>
374where
375    T: Clone + Send + 'static,
376    K: Eq + std::hash::Hash + Clone,
377    U: Eq + std::hash::Hash + Clone,
378{
379    fn deref_mut(&mut self) -> &mut Self::Target {
380        &mut self.receiver
381    }
382}
383
384impl<K, U, T> Drop for UserReceiverGuard<'_, K, U, T>
385where
386    T: Clone + Send + 'static,
387    K: Eq + std::hash::Hash + Clone,
388    U: Eq + std::hash::Hash + Clone,
389{
390    fn drop(&mut self) {
391        self.manager.end_user(self.user.clone());
392    }
393}