Skip to main content

wsio_server/namespace/
mod.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use arc_swap::ArcSwap;
5use futures_util::{
6    SinkExt,
7    StreamExt,
8};
9use http::{
10    HeaderMap,
11    Uri,
12};
13use hyper::upgrade::{
14    OnUpgrade,
15    Upgraded,
16};
17use hyper_util::rt::TokioIo;
18use kikiutils::{
19    atomic::enum_cell::AtomicEnumCell,
20    types::fx_collections::FxDashMap,
21};
22use num_enum::{
23    IntoPrimitive,
24    TryFromPrimitive,
25};
26use roaring::RoaringTreemap;
27use serde::Serialize;
28use tokio::{
29    join,
30    select,
31    spawn,
32    sync::Mutex,
33    task::JoinSet,
34};
35use tokio_tungstenite::{
36    WebSocketStream,
37    tungstenite::{
38        Message,
39        protocol::Role,
40    },
41};
42
43pub(crate) mod builder;
44mod config;
45pub mod operators;
46
47use self::{
48    config::WsIoServerNamespaceConfig,
49    operators::broadcast::WsIoServerNamespaceBroadcastOperator,
50};
51use crate::{
52    WsIoServer,
53    connection::WsIoServerConnection,
54    core::packet::WsIoPacket,
55    runtime::{
56        WsIoServerRuntime,
57        WsIoServerRuntimeStatus,
58    },
59};
60
61// Enums
62#[repr(u8)]
63#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
64enum NamespaceStatus {
65    Running,
66    Stopped,
67    Stopping,
68}
69
70// Structs
71pub struct WsIoServerNamespace {
72    pub(crate) config: WsIoServerNamespaceConfig,
73    connection_ids: ArcSwap<RoaringTreemap>,
74    connections: FxDashMap<u64, Arc<WsIoServerConnection>>,
75    connection_task_set: Mutex<JoinSet<()>>,
76    rooms: FxDashMap<String, RoaringTreemap>,
77    runtime: Arc<WsIoServerRuntime>,
78    status: AtomicEnumCell<NamespaceStatus>,
79}
80
81impl WsIoServerNamespace {
82    fn new(config: WsIoServerNamespaceConfig, runtime: Arc<WsIoServerRuntime>) -> Arc<Self> {
83        Arc::new(Self {
84            config,
85            connection_ids: ArcSwap::new(Arc::new(RoaringTreemap::new())),
86            connections: FxDashMap::default(),
87            connection_task_set: Mutex::new(JoinSet::new()),
88            rooms: FxDashMap::default(),
89            runtime,
90            status: AtomicEnumCell::new(NamespaceStatus::Running),
91        })
92    }
93
94    // Private methods
95    async fn handle_upgraded_request(
96        self: &Arc<Self>,
97        headers: HeaderMap,
98        request_uri: Uri,
99        upgraded: Upgraded,
100    ) -> Result<()> {
101        // Create ws stream
102        let mut ws_stream =
103            WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Server, Some(self.config.websocket_config))
104                .await;
105
106        // Check runtime and namespace status
107        if !self.runtime.status.is(WsIoServerRuntimeStatus::Running) || !self.status.is(NamespaceStatus::Running) {
108            ws_stream
109                .send((*self.encode_packet_to_message(&WsIoPacket::new_disconnect())?).clone())
110                .await?;
111
112            let _ = ws_stream.close(None).await;
113            return Ok(());
114        }
115
116        // Create connection
117        let (connection, mut message_rx) = WsIoServerConnection::new(headers, self.clone(), request_uri);
118
119        // Split ws stream and spawn read and write tasks
120        let (mut ws_stream_writer, mut ws_stream_reader) = ws_stream.split();
121        let connection_clone = connection.clone();
122        let mut read_ws_stream_task = spawn(async move {
123            while let Some(message) = ws_stream_reader.next().await {
124                if match message {
125                    Ok(Message::Binary(bytes)) => {
126                        // Treat any single-byte binary frame as a client heartbeat and ignore it
127                        if bytes.len() == 1 {
128                            continue;
129                        }
130
131                        connection_clone.handle_incoming_packet(&bytes).await
132                    }
133                    Ok(Message::Close(_)) => break,
134                    Ok(Message::Text(text)) => connection_clone.handle_incoming_packet(text.as_bytes()).await,
135                    Err(_) => break,
136                    _ => Ok(()),
137                }
138                .is_err()
139                {
140                    break;
141                }
142            }
143        });
144
145        let mut write_ws_stream_task = spawn(async move {
146            while let Some(message) = message_rx.recv().await {
147                let message = (*message).clone();
148                let is_close = matches!(message, Message::Close(_));
149                if ws_stream_writer.send(message).await.is_err() {
150                    break;
151                }
152
153                if is_close {
154                    let _ = ws_stream_writer.close().await;
155                    break;
156                }
157            }
158        });
159
160        // Try to init connection
161        match connection.init().await {
162            Ok(_) => {
163                // Wait for either read or write task to finish
164                select! {
165                    _ = &mut read_ws_stream_task => {
166                        write_ws_stream_task.abort();
167                    },
168                    _ = &mut write_ws_stream_task => {
169                        read_ws_stream_task.abort();
170                    },
171                }
172            }
173            Err(_) => {
174                // Close connection
175                read_ws_stream_task.abort();
176                connection.close();
177                let _ = join!(read_ws_stream_task, write_ws_stream_task);
178            }
179        }
180
181        // Cleanup connection
182        connection.cleanup().await;
183        Ok(())
184    }
185
186    // Protected methods
187    #[inline]
188    pub(crate) fn add_connection_id_to_room(&self, room_name: &str, connection_id: u64) {
189        self.rooms.entry(room_name.into()).or_default().insert(connection_id);
190    }
191
192    #[inline]
193    pub(crate) fn encode_packet_to_message(&self, packet: &WsIoPacket) -> Result<Arc<Message>> {
194        let bytes = self.config.packet_codec.encode(packet)?;
195        Ok(Arc::new(match self.config.packet_codec.is_text() {
196            true => Message::Text(unsafe { String::from_utf8_unchecked(bytes).into() }),
197            false => Message::Binary(bytes.into()),
198        }))
199    }
200
201    pub(crate) async fn handle_on_upgrade_request(
202        self: &Arc<Self>,
203        headers: HeaderMap,
204        on_upgrade: OnUpgrade,
205        request_uri: Uri,
206    ) {
207        let namespace = self.clone();
208        self.connection_task_set.lock().await.spawn(async move {
209            if let Ok(upgraded) = on_upgrade.await {
210                let _ = namespace.handle_upgraded_request(headers, request_uri, upgraded).await;
211            }
212        });
213    }
214
215    #[inline]
216    pub(crate) fn insert_connection(&self, connection: Arc<WsIoServerConnection>) {
217        self.connections.insert(connection.id(), connection.clone());
218        self.runtime.insert_connection_id(connection.id());
219        self.connection_ids.rcu(|old_connection_ids| {
220            let mut new_connection_ids = (**old_connection_ids).clone();
221            new_connection_ids.insert(connection.id());
222            new_connection_ids
223        });
224    }
225
226    #[inline]
227    pub(crate) fn remove_connection(&self, id: u64) {
228        self.connections.remove(&id);
229        self.runtime.remove_connection_id(id);
230        self.connection_ids.rcu(|old_connection_ids| {
231            let mut new_connection_ids = (**old_connection_ids).clone();
232            new_connection_ids.remove(id);
233            new_connection_ids
234        });
235    }
236
237    #[inline]
238    pub(crate) fn remove_connection_id_from_room(&self, room_name: &str, connection_id: u64) {
239        if let Some(mut entry) = self.rooms.get_mut(room_name) {
240            entry.remove(connection_id);
241            if entry.is_empty() {
242                drop(entry);
243                self.rooms.remove(room_name);
244            }
245        }
246    }
247
248    // Public methods
249    pub async fn close_all(self: &Arc<Self>) {
250        WsIoServerNamespaceBroadcastOperator::new(self.clone()).close().await;
251    }
252
253    #[inline]
254    pub fn connection_count(&self) -> usize {
255        self.connections.len()
256    }
257
258    pub async fn disconnect_all(self: &Arc<Self>) -> Result<()> {
259        WsIoServerNamespaceBroadcastOperator::new(self.clone())
260            .disconnect()
261            .await
262    }
263
264    pub async fn emit<D: Serialize>(self: &Arc<Self>, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
265        WsIoServerNamespaceBroadcastOperator::new(self.clone())
266            .emit(event, data)
267            .await
268    }
269
270    #[inline]
271    pub fn except(
272        self: &Arc<Self>,
273        room_names: impl IntoIterator<Item = impl Into<String>>,
274    ) -> WsIoServerNamespaceBroadcastOperator {
275        WsIoServerNamespaceBroadcastOperator::new(self.clone()).except(room_names)
276    }
277
278    #[inline]
279    pub fn path(&self) -> &str {
280        &self.config.path
281    }
282
283    #[inline]
284    pub fn server(&self) -> WsIoServer {
285        WsIoServer(self.runtime.clone())
286    }
287
288    pub async fn shutdown(self: &Arc<Self>) {
289        match self.status.get() {
290            NamespaceStatus::Stopped => return,
291            NamespaceStatus::Running => self.status.store(NamespaceStatus::Stopping),
292            _ => unreachable!(),
293        }
294
295        self.close_all().await;
296        let mut connection_task_set = self.connection_task_set.lock().await;
297        while connection_task_set.join_next().await.is_some() {}
298
299        self.status.store(NamespaceStatus::Stopped);
300    }
301
302    #[inline]
303    pub fn to(
304        self: &Arc<Self>,
305        room_names: impl IntoIterator<Item = impl Into<String>>,
306    ) -> WsIoServerNamespaceBroadcastOperator {
307        WsIoServerNamespaceBroadcastOperator::new(self.clone()).to(room_names)
308    }
309}