shaperail_runtime/ws/
room.rs1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use tokio::sync::{mpsc, RwLock};
5
6pub type SessionId = String;
8
9pub type SessionSender = mpsc::UnboundedSender<String>;
11
12#[derive(Clone)]
17pub struct RoomManager {
18 rooms: Arc<RwLock<HashMap<String, HashSet<SessionId>>>>,
20 sessions: Arc<RwLock<HashMap<SessionId, SessionSender>>>,
22 session_rooms: Arc<RwLock<HashMap<SessionId, HashSet<String>>>>,
24}
25
26impl RoomManager {
27 pub fn new() -> Self {
29 Self {
30 rooms: Arc::new(RwLock::new(HashMap::new())),
31 sessions: Arc::new(RwLock::new(HashMap::new())),
32 session_rooms: Arc::new(RwLock::new(HashMap::new())),
33 }
34 }
35
36 pub async fn register_session(&self, session_id: &str, sender: SessionSender) {
38 self.sessions
39 .write()
40 .await
41 .insert(session_id.to_string(), sender);
42 self.session_rooms
43 .write()
44 .await
45 .insert(session_id.to_string(), HashSet::new());
46 }
47
48 pub async fn remove_session(&self, session_id: &str) {
50 self.sessions.write().await.remove(session_id);
51
52 let rooms = self
53 .session_rooms
54 .write()
55 .await
56 .remove(session_id)
57 .unwrap_or_default();
58
59 let mut room_map = self.rooms.write().await;
60 for room in rooms {
61 if let Some(members) = room_map.get_mut(&room) {
62 members.remove(session_id);
63 if members.is_empty() {
64 room_map.remove(&room);
65 }
66 }
67 }
68 }
69
70 pub async fn subscribe(&self, session_id: &str, room: &str) {
72 self.rooms
73 .write()
74 .await
75 .entry(room.to_string())
76 .or_default()
77 .insert(session_id.to_string());
78
79 if let Some(session_rooms) = self.session_rooms.write().await.get_mut(session_id) {
80 session_rooms.insert(room.to_string());
81 }
82 }
83
84 pub async fn unsubscribe(&self, session_id: &str, room: &str) {
86 let mut room_map = self.rooms.write().await;
87 if let Some(members) = room_map.get_mut(room) {
88 members.remove(session_id);
89 if members.is_empty() {
90 room_map.remove(room);
91 }
92 }
93
94 if let Some(session_rooms) = self.session_rooms.write().await.get_mut(session_id) {
95 session_rooms.remove(room);
96 }
97 }
98
99 pub async fn broadcast_to_room(&self, room: &str, message: &str) {
101 let rooms = self.rooms.read().await;
102 let sessions = self.sessions.read().await;
103
104 if let Some(members) = rooms.get(room) {
105 for session_id in members {
106 if let Some(sender) = sessions.get(session_id) {
107 let _ = sender.send(message.to_string());
109 }
110 }
111 }
112 }
113
114 pub async fn session_count(&self) -> usize {
116 self.sessions.read().await.len()
117 }
118
119 pub async fn room_member_count(&self, room: &str) -> usize {
121 self.rooms
122 .read()
123 .await
124 .get(room)
125 .map(|s| s.len())
126 .unwrap_or(0)
127 }
128}
129
130impl Default for RoomManager {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn register_and_remove_session() {
        let mgr = RoomManager::new();
        let (tx, _rx) = mpsc::unbounded_channel();

        mgr.register_session("s1", tx).await;
        assert_eq!(mgr.session_count().await, 1);

        mgr.remove_session("s1").await;
        assert_eq!(mgr.session_count().await, 0);
    }

    #[tokio::test]
    async fn subscribe_and_broadcast() {
        let mgr = RoomManager::new();
        let (tx1, mut rx1) = mpsc::unbounded_channel();
        let (tx2, mut rx2) = mpsc::unbounded_channel();

        mgr.register_session("s1", tx1).await;
        mgr.register_session("s2", tx2).await;

        mgr.subscribe("s1", "org:123").await;
        mgr.subscribe("s2", "org:123").await;

        assert_eq!(mgr.room_member_count("org:123").await, 2);

        mgr.broadcast_to_room("org:123", r#"{"hello":"world"}"#)
            .await;

        assert_eq!(rx1.recv().await.unwrap(), r#"{"hello":"world"}"#);
        assert_eq!(rx2.recv().await.unwrap(), r#"{"hello":"world"}"#);
    }

    #[tokio::test]
    async fn unsubscribe_stops_broadcast() {
        let mgr = RoomManager::new();
        let (tx1, mut rx1) = mpsc::unbounded_channel();
        let (tx2, mut rx2) = mpsc::unbounded_channel();

        mgr.register_session("s1", tx1).await;
        mgr.register_session("s2", tx2).await;

        mgr.subscribe("s1", "room:a").await;
        mgr.subscribe("s2", "room:a").await;
        mgr.unsubscribe("s2", "room:a").await;

        assert_eq!(mgr.room_member_count("room:a").await, 1);

        mgr.broadcast_to_room("room:a", "msg").await;
        assert_eq!(rx1.recv().await.unwrap(), "msg");
        // s2 left the room before the broadcast, so nothing may be queued for it.
        assert!(rx2.try_recv().is_err());
    }

    #[tokio::test]
    async fn broadcast_is_scoped_to_room() {
        let mgr = RoomManager::new();
        let (tx1, mut rx1) = mpsc::unbounded_channel();
        let (tx2, mut rx2) = mpsc::unbounded_channel();

        mgr.register_session("s1", tx1).await;
        mgr.register_session("s2", tx2).await;
        mgr.subscribe("s1", "room:a").await;
        mgr.subscribe("s2", "room:b").await;

        mgr.broadcast_to_room("room:a", "only-a").await;

        assert_eq!(rx1.recv().await.unwrap(), "only-a");
        assert!(rx2.try_recv().is_err());
    }

    #[tokio::test]
    async fn remove_session_cleans_up_rooms() {
        let mgr = RoomManager::new();
        let (tx, _rx) = mpsc::unbounded_channel();

        mgr.register_session("s1", tx).await;
        mgr.subscribe("s1", "room:a").await;
        mgr.subscribe("s1", "room:b").await;

        assert_eq!(mgr.room_member_count("room:a").await, 1);
        assert_eq!(mgr.room_member_count("room:b").await, 1);

        mgr.remove_session("s1").await;

        assert_eq!(mgr.room_member_count("room:a").await, 0);
        assert_eq!(mgr.room_member_count("room:b").await, 0);
    }

    #[tokio::test]
    async fn broadcast_to_empty_room() {
        let mgr = RoomManager::new();
        mgr.broadcast_to_room("nonexistent", "msg").await;
        // Broadcasting must neither panic nor create the room.
        assert_eq!(mgr.room_member_count("nonexistent").await, 0);
    }

    #[tokio::test]
    async fn disconnect_during_broadcast() {
        let mgr = RoomManager::new();
        let (tx, rx) = mpsc::unbounded_channel();

        mgr.register_session("s1", tx).await;
        mgr.subscribe("s1", "room:a").await;

        // Closing the receiver makes `send` fail. The broadcast must swallow that
        // error rather than panic, and must leave membership untouched.
        drop(rx);

        mgr.broadcast_to_room("room:a", "msg").await;
        assert_eq!(mgr.session_count().await, 1);
        assert_eq!(mgr.room_member_count("room:a").await, 1);
    }
}