Skip to main content

shaperail_runtime/ws/
room.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use tokio::sync::{mpsc, RwLock};
5
6/// A unique identifier for a connected WebSocket session.
7pub type SessionId = String;
8
9/// A sender that delivers text frames to a connected client.
10pub type SessionSender = mpsc::UnboundedSender<String>;
11
12/// Manages room subscriptions and message routing for a single channel.
13///
14/// Thread-safe via `Arc<RwLock<...>>` — designed for concurrent access
15/// from multiple WebSocket sessions and the Redis pub/sub listener.
16#[derive(Clone)]
17pub struct RoomManager {
18    /// room_name -> set of session IDs subscribed to that room.
19    rooms: Arc<RwLock<HashMap<String, HashSet<SessionId>>>>,
20    /// session_id -> sender for delivering messages to that session.
21    sessions: Arc<RwLock<HashMap<SessionId, SessionSender>>>,
22    /// session_id -> set of rooms that session is subscribed to.
23    session_rooms: Arc<RwLock<HashMap<SessionId, HashSet<String>>>>,
24}
25
26impl RoomManager {
27    /// Creates a new empty room manager.
28    pub fn new() -> Self {
29        Self {
30            rooms: Arc::new(RwLock::new(HashMap::new())),
31            sessions: Arc::new(RwLock::new(HashMap::new())),
32            session_rooms: Arc::new(RwLock::new(HashMap::new())),
33        }
34    }
35
36    /// Registers a new session with its message sender.
37    pub async fn register_session(&self, session_id: &str, sender: SessionSender) {
38        self.sessions
39            .write()
40            .await
41            .insert(session_id.to_string(), sender);
42        self.session_rooms
43            .write()
44            .await
45            .insert(session_id.to_string(), HashSet::new());
46    }
47
48    /// Removes a session and all its room subscriptions.
49    pub async fn remove_session(&self, session_id: &str) {
50        self.sessions.write().await.remove(session_id);
51
52        let rooms = self
53            .session_rooms
54            .write()
55            .await
56            .remove(session_id)
57            .unwrap_or_default();
58
59        let mut room_map = self.rooms.write().await;
60        for room in rooms {
61            if let Some(members) = room_map.get_mut(&room) {
62                members.remove(session_id);
63                if members.is_empty() {
64                    room_map.remove(&room);
65                }
66            }
67        }
68    }
69
70    /// Subscribes a session to a room.
71    pub async fn subscribe(&self, session_id: &str, room: &str) {
72        self.rooms
73            .write()
74            .await
75            .entry(room.to_string())
76            .or_default()
77            .insert(session_id.to_string());
78
79        if let Some(session_rooms) = self.session_rooms.write().await.get_mut(session_id) {
80            session_rooms.insert(room.to_string());
81        }
82    }
83
84    /// Unsubscribes a session from a room.
85    pub async fn unsubscribe(&self, session_id: &str, room: &str) {
86        let mut room_map = self.rooms.write().await;
87        if let Some(members) = room_map.get_mut(room) {
88            members.remove(session_id);
89            if members.is_empty() {
90                room_map.remove(room);
91            }
92        }
93
94        if let Some(session_rooms) = self.session_rooms.write().await.get_mut(session_id) {
95            session_rooms.remove(room);
96        }
97    }
98
99    /// Broadcasts a text message to all sessions subscribed to a room.
100    pub async fn broadcast_to_room(&self, room: &str, message: &str) {
101        let rooms = self.rooms.read().await;
102        let sessions = self.sessions.read().await;
103
104        if let Some(members) = rooms.get(room) {
105            for session_id in members {
106                if let Some(sender) = sessions.get(session_id) {
107                    // Ignore send errors — the session may have disconnected
108                    let _ = sender.send(message.to_string());
109                }
110            }
111        }
112    }
113
114    /// Returns the number of currently registered sessions.
115    pub async fn session_count(&self) -> usize {
116        self.sessions.read().await.len()
117    }
118
119    /// Returns the number of sessions in a specific room.
120    pub async fn room_member_count(&self, room: &str) -> usize {
121        self.rooms
122            .read()
123            .await
124            .get(room)
125            .map(|s| s.len())
126            .unwrap_or(0)
127    }
128}
129
130impl Default for RoomManager {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Registering adds to the session count; removing takes it back to zero.
    #[tokio::test]
    async fn register_and_remove_session() {
        let mgr = RoomManager::new();
        let (tx, _rx) = mpsc::unbounded_channel();

        mgr.register_session("s1", tx).await;
        assert_eq!(mgr.session_count().await, 1);

        mgr.remove_session("s1").await;
        assert_eq!(mgr.session_count().await, 0);
    }

    /// Every subscriber of a room receives a broadcast sent to that room.
    #[tokio::test]
    async fn subscribe_and_broadcast() {
        let mgr = RoomManager::new();
        let (tx1, mut rx1) = mpsc::unbounded_channel();
        let (tx2, mut rx2) = mpsc::unbounded_channel();

        mgr.register_session("s1", tx1).await;
        mgr.register_session("s2", tx2).await;

        mgr.subscribe("s1", "org:123").await;
        mgr.subscribe("s2", "org:123").await;

        assert_eq!(mgr.room_member_count("org:123").await, 2);

        mgr.broadcast_to_room("org:123", r#"{"hello":"world"}"#)
            .await;

        assert_eq!(rx1.recv().await.unwrap(), r#"{"hello":"world"}"#);
        assert_eq!(rx2.recv().await.unwrap(), r#"{"hello":"world"}"#);
    }

    /// After unsubscribing, a session no longer receives the room's messages
    /// while the remaining member still does.
    #[tokio::test]
    async fn unsubscribe_stops_broadcast() {
        let mgr = RoomManager::new();
        let (tx1, mut rx1) = mpsc::unbounded_channel();
        let (tx2, mut rx2) = mpsc::unbounded_channel();

        mgr.register_session("s1", tx1).await;
        mgr.register_session("s2", tx2).await;

        mgr.subscribe("s1", "room:a").await;
        mgr.subscribe("s2", "room:a").await;
        mgr.unsubscribe("s2", "room:a").await;

        assert_eq!(mgr.room_member_count("room:a").await, 1);

        mgr.broadcast_to_room("room:a", "msg").await;
        assert_eq!(rx1.recv().await.unwrap(), "msg");
        // The broadcast sends synchronously, so anything delivered to s2
        // would already be queued here; its channel must still be empty.
        assert!(rx2.try_recv().is_err());
    }

    /// Removing a session drops it from every room it had joined.
    #[tokio::test]
    async fn remove_session_cleans_up_rooms() {
        let mgr = RoomManager::new();
        let (tx, _rx) = mpsc::unbounded_channel();

        mgr.register_session("s1", tx).await;
        mgr.subscribe("s1", "room:a").await;
        mgr.subscribe("s1", "room:b").await;

        assert_eq!(mgr.room_member_count("room:a").await, 1);
        assert_eq!(mgr.room_member_count("room:b").await, 1);

        mgr.remove_session("s1").await;

        assert_eq!(mgr.room_member_count("room:a").await, 0);
        assert_eq!(mgr.room_member_count("room:b").await, 0);
    }

    /// Broadcasting to a room nobody joined is a harmless no-op.
    #[tokio::test]
    async fn broadcast_to_empty_room() {
        let mgr = RoomManager::new();
        // Should not panic
        mgr.broadcast_to_room("nonexistent", "msg").await;
    }

    /// A closed receiver makes the send fail; the broadcast must tolerate it.
    #[tokio::test]
    async fn disconnect_during_broadcast() {
        let mgr = RoomManager::new();
        let (tx, rx) = mpsc::unbounded_channel();

        mgr.register_session("s1", tx).await;
        mgr.subscribe("s1", "room:a").await;

        // Drop receiver to simulate disconnect
        drop(rx);

        // Should not panic — just ignore the send error
        mgr.broadcast_to_room("room:a", "msg").await;
    }
}