// wsio_server/namespace/mod.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use arc_swap::ArcSwap;
5use futures_util::{
6    SinkExt,
7    StreamExt,
8};
9use http::{
10    HeaderMap,
11    Uri,
12};
13use hyper::upgrade::{
14    OnUpgrade,
15    Upgraded,
16};
17use hyper_util::rt::TokioIo;
18use kikiutils::{
19    atomic::enum_cell::AtomicEnumCell,
20    types::fx_collections::FxDashMap,
21};
22use num_enum::{
23    IntoPrimitive,
24    TryFromPrimitive,
25};
26use roaring::RoaringTreemap;
27use serde::Serialize;
28use tokio::{
29    join,
30    select,
31    spawn,
32    sync::Mutex,
33    task::JoinSet,
34    time::timeout,
35};
36use tokio_tungstenite::{
37    WebSocketStream,
38    tungstenite::{
39        Message,
40        protocol::Role,
41    },
42};
43
44pub(crate) mod builder;
45mod config;
46pub mod operators;
47
48use self::{
49    config::WsIoServerNamespaceConfig,
50    operators::broadcast::WsIoServerNamespaceBroadcastOperator,
51};
52use crate::{
53    WsIoServer,
54    connection::WsIoServerConnection,
55    core::packet::WsIoPacket,
56    runtime::{
57        WsIoServerRuntime,
58        WsIoServerRuntimeStatus,
59    },
60};
61
62// Enums
63#[repr(u8)]
64#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
65enum NamespaceStatus {
66    Running,
67    Stopped,
68    Stopping,
69}
70
71// Structs
72pub struct WsIoServerNamespace {
73    pub(crate) config: WsIoServerNamespaceConfig,
74    connection_ids: ArcSwap<RoaringTreemap>,
75    connections: FxDashMap<u64, Arc<WsIoServerConnection>>,
76    connection_task_set: Mutex<JoinSet<()>>,
77    rooms: FxDashMap<String, RoaringTreemap>,
78    runtime: Arc<WsIoServerRuntime>,
79    status: AtomicEnumCell<NamespaceStatus>,
80}
81
82impl WsIoServerNamespace {
83    fn new(config: WsIoServerNamespaceConfig, runtime: Arc<WsIoServerRuntime>) -> Arc<Self> {
84        Arc::new(Self {
85            config,
86            connection_ids: ArcSwap::new(Arc::new(RoaringTreemap::new())),
87            connections: FxDashMap::default(),
88            connection_task_set: Mutex::new(JoinSet::new()),
89            rooms: FxDashMap::default(),
90            runtime,
91            status: AtomicEnumCell::new(NamespaceStatus::Running),
92        })
93    }
94
95    // Private methods
96    async fn handle_upgraded_request(
97        self: &Arc<Self>,
98        headers: HeaderMap,
99        request_uri: Uri,
100        upgraded: Upgraded,
101    ) -> Result<()> {
102        // Create ws stream
103        let mut ws_stream =
104            WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Server, Some(self.config.websocket_config))
105                .await;
106
107        // Check runtime and namespace status
108        if !self.runtime.status.is(WsIoServerRuntimeStatus::Running) || !self.status.is(NamespaceStatus::Running) {
109            ws_stream
110                .send((*self.encode_packet_to_message(&WsIoPacket::new_disconnect())?).clone())
111                .await?;
112
113            let _ = ws_stream.close(None).await;
114            return Ok(());
115        }
116
117        // Create connection
118        let (connection, mut message_rx) = WsIoServerConnection::new(headers, self.clone(), request_uri);
119
120        // Split ws stream and spawn read and write tasks
121        let (mut ws_stream_writer, mut ws_stream_reader) = ws_stream.split();
122        let connection_clone = connection.clone();
123        let mut read_ws_stream_task = spawn(async move {
124            while let Some(message) = ws_stream_reader.next().await {
125                if match message {
126                    Ok(Message::Binary(bytes)) => {
127                        // Treat any single-byte binary frame as a client heartbeat and ignore it
128                        if bytes.len() == 1 {
129                            continue;
130                        }
131
132                        connection_clone.handle_incoming_packet(&bytes).await
133                    }
134                    Ok(Message::Close(_)) => break,
135                    Ok(Message::Text(text)) => connection_clone.handle_incoming_packet(text.as_bytes()).await,
136                    Err(_) => break,
137                    _ => Ok(()),
138                }
139                .is_err()
140                {
141                    break;
142                }
143            }
144        });
145
146        let mut write_ws_stream_task = spawn(async move {
147            while let Some(message) = message_rx.recv().await {
148                let message = (*message).clone();
149                let is_close = matches!(message, Message::Close(_));
150                if ws_stream_writer.send(message).await.is_err() {
151                    break;
152                }
153
154                if is_close {
155                    let _ = ws_stream_writer.close().await;
156                    break;
157                }
158            }
159        });
160
161        // Try to init connection
162        match connection.init().await {
163            Ok(_) => {
164                // Wait for either read or write task to finish
165                select! {
166                    _ = &mut read_ws_stream_task => {
167                        write_ws_stream_task.abort();
168                    },
169                    _ = &mut write_ws_stream_task => {
170                        read_ws_stream_task.abort();
171                    },
172                }
173            }
174            Err(_) => {
175                // Close connection
176                read_ws_stream_task.abort();
177                connection.close();
178                let _ = join!(read_ws_stream_task, write_ws_stream_task);
179            }
180        }
181
182        // Cleanup connection
183        connection.cleanup().await;
184        Ok(())
185    }
186
187    // Protected methods
188    #[inline]
189    pub(crate) fn add_connection_id_to_room(&self, room_name: &str, connection_id: u64) {
190        self.rooms.entry(room_name.into()).or_default().insert(connection_id);
191    }
192
193    #[inline]
194    pub(crate) fn encode_packet_to_message(&self, packet: &WsIoPacket) -> Result<Arc<Message>> {
195        let bytes = self.config.packet_codec.encode(packet)?;
196        Ok(Arc::new(match self.config.packet_codec.is_text() {
197            true => Message::Text(unsafe { String::from_utf8_unchecked(bytes).into() }),
198            false => Message::Binary(bytes.into()),
199        }))
200    }
201
202    pub(crate) async fn handle_on_upgrade_request(
203        self: &Arc<Self>,
204        headers: HeaderMap,
205        on_upgrade: OnUpgrade,
206        request_uri: Uri,
207    ) {
208        let namespace = self.clone();
209        self.connection_task_set.lock().await.spawn(async move {
210            if let Ok(Ok(upgraded)) = timeout(namespace.config.http_request_upgrade_timeout, on_upgrade).await {
211                let _ = namespace.handle_upgraded_request(headers, request_uri, upgraded).await;
212            }
213        });
214    }
215
216    #[inline]
217    pub(crate) fn insert_connection(&self, connection: Arc<WsIoServerConnection>) {
218        self.connections.insert(connection.id(), connection.clone());
219        self.runtime.insert_connection_id(connection.id());
220        self.connection_ids.rcu(|old_connection_ids| {
221            let mut new_connection_ids = (**old_connection_ids).clone();
222            new_connection_ids.insert(connection.id());
223            new_connection_ids
224        });
225    }
226
227    #[inline]
228    pub(crate) fn remove_connection(&self, id: u64) {
229        self.connections.remove(&id);
230        self.runtime.remove_connection_id(id);
231        self.connection_ids.rcu(|old_connection_ids| {
232            let mut new_connection_ids = (**old_connection_ids).clone();
233            new_connection_ids.remove(id);
234            new_connection_ids
235        });
236    }
237
238    #[inline]
239    pub(crate) fn remove_connection_id_from_room(&self, room_name: &str, connection_id: u64) {
240        if let Some(mut entry) = self.rooms.get_mut(room_name) {
241            entry.remove(connection_id);
242        }
243
244        self.rooms.remove_if(room_name, |_, entry| entry.is_empty());
245    }
246
247    // Public methods
248    pub async fn close_all(self: &Arc<Self>) {
249        WsIoServerNamespaceBroadcastOperator::new(self.clone()).close().await;
250    }
251
252    #[inline]
253    pub fn connection_count(&self) -> usize {
254        self.connections.len()
255    }
256
257    pub async fn disconnect_all(self: &Arc<Self>) -> Result<()> {
258        WsIoServerNamespaceBroadcastOperator::new(self.clone())
259            .disconnect()
260            .await
261    }
262
263    pub async fn emit<D: Serialize>(self: &Arc<Self>, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
264        WsIoServerNamespaceBroadcastOperator::new(self.clone())
265            .emit(event, data)
266            .await
267    }
268
269    #[inline]
270    pub fn except(
271        self: &Arc<Self>,
272        room_names: impl IntoIterator<Item = impl Into<String>>,
273    ) -> WsIoServerNamespaceBroadcastOperator {
274        WsIoServerNamespaceBroadcastOperator::new(self.clone()).except(room_names)
275    }
276
277    #[inline]
278    pub fn path(&self) -> &str {
279        &self.config.path
280    }
281
282    #[inline]
283    pub fn server(&self) -> WsIoServer {
284        WsIoServer(self.runtime.clone())
285    }
286
287    pub async fn shutdown(self: &Arc<Self>) {
288        match self.status.get() {
289            NamespaceStatus::Stopped => return,
290            NamespaceStatus::Running => self.status.store(NamespaceStatus::Stopping),
291            _ => unreachable!(),
292        }
293
294        self.close_all().await;
295        let mut connection_task_set = self.connection_task_set.lock().await;
296        while connection_task_set.join_next().await.is_some() {}
297
298        self.status.store(NamespaceStatus::Stopped);
299    }
300
301    #[inline]
302    pub fn to(
303        self: &Arc<Self>,
304        room_names: impl IntoIterator<Item = impl Into<String>>,
305    ) -> WsIoServerNamespaceBroadcastOperator {
306        WsIoServerNamespaceBroadcastOperator::new(self.clone()).to(room_names)
307    }
308}
309
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;

    use super::*;
    use crate::{
        config::WsIoServerConfig,
        core::packet::codecs::WsIoPacketCodec,
    };

    /// Builds a runtime with short timeouts and registers a `/test` namespace on it.
    fn create_test_namespace() -> Arc<WsIoServerNamespace> {
        let runtime = WsIoServerRuntime::new(WsIoServerConfig {
            broadcast_concurrency_limit: 16,
            http_request_upgrade_timeout: Duration::from_secs(3),
            init_request_handler_timeout: Duration::from_secs(3),
            init_response_handler_timeout: Duration::from_secs(3),
            init_response_timeout: Duration::from_secs(3),
            middleware_execution_timeout: Duration::from_secs(3),
            on_close_handler_timeout: Duration::from_secs(3),
            on_connect_handler_timeout: Duration::from_secs(3),
            packet_codec: WsIoPacketCodec::SerdeJson,
            request_path: "/socket".into(),
            websocket_config: WebSocketConfig::default(),
        });
        runtime.new_namespace_builder("/test").register().unwrap()
    }

    #[tokio::test]
    async fn test_namespace_new() {
        let namespace = create_test_namespace();
        assert_eq!(namespace.path(), "/test");
        assert_eq!(namespace.connection_count(), 0);
        assert!(namespace.status.is(NamespaceStatus::Running));
        assert!(namespace.rooms.is_empty());
        assert!(namespace.connection_ids.load().is_empty());
    }

    #[tokio::test]
    async fn test_namespace_connection_count() {
        let namespace = create_test_namespace();
        assert_eq!(namespace.connection_count(), 0);
    }

    #[tokio::test]
    async fn test_namespace_server() {
        let namespace = create_test_namespace();
        let server = namespace.server();
        // The server handle must wrap the very runtime the namespace belongs to
        assert!(Arc::ptr_eq(&server.0, &namespace.runtime));
    }

    #[tokio::test]
    async fn test_namespace_to_broadcast_operator() {
        let namespace = create_test_namespace();
        namespace.to(["room1", "room2"]);
    }

    #[tokio::test]
    async fn test_namespace_except_broadcast_operator() {
        let namespace = create_test_namespace();
        namespace.except(["room1", "room2"]);
    }

    #[tokio::test]
    async fn test_namespace_add_remove_connection_id_to_room() {
        let namespace = create_test_namespace();
        namespace.add_connection_id_to_room("room1", 1);
        namespace.add_connection_id_to_room("room1", 2);
        namespace.add_connection_id_to_room("room2", 3);

        assert_eq!(namespace.rooms.len(), 2);
        assert_eq!(namespace.rooms.get("room1").map(|room| room.len()), Some(2));
        assert!(namespace.rooms.get("room2").is_some_and(|room| room.contains(3)));

        // Removing one member keeps the room alive
        namespace.remove_connection_id_from_room("room1", 1);
        assert_eq!(namespace.rooms.get("room1").map(|room| room.len()), Some(1));

        // Rooms that become empty are dropped entirely
        namespace.remove_connection_id_from_room("room1", 2);
        namespace.remove_connection_id_from_room("room2", 3);
        assert!(namespace.rooms.is_empty());
    }

    #[tokio::test]
    async fn test_namespace_remove_connection_id_from_empty_room() {
        let namespace = create_test_namespace();
        // Removing from non-existent room should not panic or create the room
        namespace.remove_connection_id_from_room("nonexistent", 1);
        assert!(namespace.rooms.is_empty());
    }

    #[tokio::test]
    async fn test_namespace_encode_packet_to_message() {
        let namespace = create_test_namespace();
        let packet = WsIoPacket::new_disconnect();
        let message = namespace.encode_packet_to_message(&packet).unwrap();
        assert!(matches!(*message, Message::Text(_) | Message::Binary(_)));
    }

    #[tokio::test]
    async fn test_namespace_shutdown_idempotent() {
        let namespace = create_test_namespace();
        namespace.shutdown().await;
        assert!(namespace.status.is(NamespaceStatus::Stopped));

        // Shutting down again should be safe and leave the namespace stopped
        namespace.shutdown().await;
        assert!(namespace.status.is(NamespaceStatus::Stopped));
        assert_eq!(namespace.connection_count(), 0);
    }

    #[tokio::test]
    async fn test_broadcast_operator_new() {
        let namespace = create_test_namespace();
        // Just verify we can create an operator
        namespace.to(["room1", "room2"]);
    }

    #[tokio::test]
    async fn test_broadcast_operator_to_chaining() {
        let namespace = create_test_namespace();
        // Chaining should work - just verify it doesn't panic
        namespace.to(["room1"]).to(["room2"]);
    }

    #[tokio::test]
    async fn test_broadcast_operator_except_chaining() {
        let namespace = create_test_namespace();
        // Chaining should work - just verify it doesn't panic
        namespace.except(["room1"]).except(["room2"]);
    }

    #[tokio::test]
    async fn test_broadcast_operator_except_connection_ids() {
        let namespace = create_test_namespace();
        // except_connection_ids is on the broadcast operator, not namespace
        namespace
            .except([1.to_string()])
            .except_connection_ids([1, 2, 3]);
    }

    #[tokio::test]
    async fn test_broadcast_operator_to_with_empty_rooms() {
        let namespace = create_test_namespace();
        // Empty rooms - should still work (broadcast to all)
        namespace.to(Vec::<String>::new());
    }

    #[tokio::test]
    async fn test_broadcast_operator_combined() {
        let namespace = create_test_namespace();
        // Combined chaining should work without panicking
        namespace
            .to(["room1", "room2"])
            .except(["room3"])
            .except_connection_ids([100]);
    }

    #[tokio::test]
    async fn test_broadcast_operator_disconnect_with_no_connections() {
        let namespace = create_test_namespace();
        // disconnect with no connections should return Ok
        let op = namespace.to(["room1"]);
        let result = op.clone().disconnect().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_broadcast_operator_emit_requires_running() {
        let namespace = create_test_namespace();
        // Shutdown to make status invalid
        namespace.shutdown().await;

        let op = namespace.to(["room1"]);
        let result = op.emit("event", Option::<&()>::None).await;
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("invalid status"));
    }

    #[tokio::test]
    async fn test_broadcast_operator_close_is_noop_when_empty() {
        let namespace = create_test_namespace();
        // close with no connections should not panic
        let op = namespace.to(["room1"]);
        op.clone().close().await;
        assert_eq!(namespace.connection_count(), 0);
    }
}