wsio_server/namespace/mod.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use futures_util::{
5    SinkExt,
6    StreamExt,
7};
8use http::HeaderMap;
9use hyper::upgrade::{
10    OnUpgrade,
11    Upgraded,
12};
13use hyper_util::rt::TokioIo;
14use num_enum::{
15    IntoPrimitive,
16    TryFromPrimitive,
17};
18use serde::Serialize;
19use tokio::{
20    join,
21    select,
22    spawn,
23    sync::Mutex,
24    task::JoinSet,
25};
26use tokio_tungstenite::{
27    WebSocketStream,
28    tungstenite::{
29        Message,
30        protocol::Role,
31    },
32};
33
34pub(crate) mod builder;
35mod config;
36pub mod operators;
37
38use self::{
39    config::WsIoServerNamespaceConfig,
40    operators::broadcast::WsIoServerNamespaceBroadcastOperator,
41};
42use crate::{
43    WsIoServer,
44    connection::WsIoServerConnection,
45    core::{
46        atomic::status::AtomicStatus,
47        packet::WsIoPacket,
48        types::hashers::{
49            FxDashMap,
50            FxDashSet,
51        },
52    },
53    runtime::{
54        WsIoServerRuntime,
55        WsIoServerRuntimeStatus,
56    },
57};
58
59// Enums
60#[repr(u8)]
61#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
62enum NamespaceStatus {
63    Running,
64    Stopped,
65    Stopping,
66}
67
// Structs
/// A WebSocket namespace served at `config.path`.
///
/// Owns the live connections of this namespace, their room memberships, and the
/// tasks that serve each upgraded connection.
pub struct WsIoServerNamespace {
    /// Namespace configuration (path, packet codec, WebSocket settings).
    pub(crate) config: WsIoServerNamespaceConfig,
    /// Live connections keyed by `WsIoServerConnection::id()`.
    connections: FxDashMap<u64, Arc<WsIoServerConnection>>,
    /// One task per upgrade request; `shutdown` joins all of them.
    connection_task_set: Mutex<JoinSet<()>>,
    /// Room name -> ids of the connections currently in that room. Empty rooms
    /// are removed when their last member leaves.
    rooms: FxDashMap<String, Arc<FxDashSet<u64>>>,
    /// Shared server runtime; also tracks connection ids server-wide.
    runtime: Arc<WsIoServerRuntime>,
    /// Current lifecycle state of this namespace.
    status: AtomicStatus<NamespaceStatus>,
}
77
78impl WsIoServerNamespace {
79    fn new(config: WsIoServerNamespaceConfig, runtime: Arc<WsIoServerRuntime>) -> Arc<Self> {
80        Arc::new(Self {
81            config,
82            connections: FxDashMap::default(),
83            connection_task_set: Mutex::new(JoinSet::new()),
84            rooms: FxDashMap::default(),
85            runtime,
86            status: AtomicStatus::new(NamespaceStatus::Running),
87        })
88    }
89
90    // Private methods
91    async fn handle_upgraded_request(self: &Arc<Self>, headers: HeaderMap, upgraded: Upgraded) -> Result<()> {
92        // Create ws stream
93        let mut ws_stream =
94            WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Server, Some(self.config.websocket_config))
95                .await;
96
97        // Check runtime and namespace status
98        if !self.runtime.status.is(WsIoServerRuntimeStatus::Running) || !self.status.is(NamespaceStatus::Running) {
99            ws_stream
100                .send((*self.encode_packet_to_message(&WsIoPacket::new_disconnect())?).clone())
101                .await?;
102
103            let _ = ws_stream.close(None).await;
104            return Ok(());
105        }
106
107        // Create connection
108        let (connection, mut message_rx) = WsIoServerConnection::new(headers, self.clone());
109
110        // Split ws stream and spawn read and write tasks
111        let (mut ws_stream_writer, mut ws_stream_reader) = ws_stream.split();
112        let connection_clone = connection.clone();
113        let mut read_ws_stream_task = spawn(async move {
114            while let Some(message) = ws_stream_reader.next().await {
115                if match message {
116                    Ok(Message::Binary(bytes)) => connection_clone.handle_incoming_packet(&bytes).await,
117                    Ok(Message::Close(_)) => break,
118                    Ok(Message::Text(text)) => connection_clone.handle_incoming_packet(text.as_bytes()).await,
119                    Err(_) => break,
120                    _ => Ok(()),
121                }
122                .is_err()
123                {
124                    break;
125                }
126            }
127        });
128
129        let mut write_ws_stream_task = spawn(async move {
130            while let Some(message) = message_rx.recv().await {
131                let message = (*message).clone();
132                let is_close = matches!(message, Message::Close(_));
133                if ws_stream_writer.send(message).await.is_err() {
134                    break;
135                }
136
137                if is_close {
138                    let _ = ws_stream_writer.close().await;
139                    break;
140                }
141            }
142        });
143
144        // Try to init connection
145        match connection.init().await {
146            Ok(_) => {
147                // Wait for either read or write task to finish
148                select! {
149                    _ = &mut read_ws_stream_task => {
150                        write_ws_stream_task.abort();
151                    },
152                    _ = &mut write_ws_stream_task => {
153                        read_ws_stream_task.abort();
154                    },
155                }
156            }
157            Err(_) => {
158                // Close connection
159                read_ws_stream_task.abort();
160                connection.close();
161                let _ = join!(read_ws_stream_task, write_ws_stream_task);
162            }
163        }
164
165        // Cleanup connection
166        connection.cleanup().await;
167        Ok(())
168    }
169
170    // Protected methods
171    #[inline]
172    pub(crate) fn add_connection_id_to_room(&self, room_name: &str, connection_id: u64) {
173        self.rooms
174            .entry(room_name.to_string())
175            .or_default()
176            .clone()
177            .insert(connection_id);
178    }
179
180    #[inline]
181    pub(crate) fn encode_packet_to_message(&self, packet: &WsIoPacket) -> Result<Arc<Message>> {
182        let bytes = self.config.packet_codec.encode(packet)?;
183        Ok(Arc::new(match self.config.packet_codec.is_text() {
184            true => Message::Text(unsafe { String::from_utf8_unchecked(bytes).into() }),
185            false => Message::Binary(bytes.into()),
186        }))
187    }
188
189    pub(crate) async fn handle_on_upgrade_request(self: &Arc<Self>, headers: HeaderMap, on_upgrade: OnUpgrade) {
190        let namespace = self.clone();
191        self.connection_task_set.lock().await.spawn(async move {
192            if let Ok(upgraded) = on_upgrade.await {
193                let _ = namespace.handle_upgraded_request(headers, upgraded).await;
194            }
195        });
196    }
197
198    #[inline]
199    pub(crate) fn insert_connection(&self, connection: Arc<WsIoServerConnection>) {
200        self.connections.insert(connection.id(), connection.clone());
201        self.runtime.insert_connection_id(connection.id());
202    }
203
204    #[inline]
205    pub(crate) fn remove_connection(&self, id: u64) {
206        self.connections.remove(&id);
207        self.runtime.remove_connection_id(id);
208    }
209
210    #[inline]
211    pub(crate) fn remove_connection_id_from_room(&self, room_name: &str, connection_id: u64) {
212        if let Some(room) = self.rooms.get(room_name).map(|entry| entry.clone()) {
213            room.remove(&connection_id);
214            if room.is_empty() {
215                self.rooms.remove(room_name);
216            }
217        }
218    }
219
220    // Public methods
221    #[inline]
222    pub fn connection_count(&self) -> usize {
223        self.connections.len()
224    }
225
226    pub async fn emit<D: Serialize>(self: &Arc<Self>, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
227        WsIoServerNamespaceBroadcastOperator::new(self.clone())
228            .emit(event, data)
229            .await
230    }
231
232    #[inline]
233    pub fn except<I: IntoIterator<Item = S>, S: AsRef<str>>(
234        self: &Arc<Self>,
235        room_names: I,
236    ) -> WsIoServerNamespaceBroadcastOperator {
237        WsIoServerNamespaceBroadcastOperator::new(self.clone()).except(room_names)
238    }
239
240    #[inline]
241    pub fn path(&self) -> &str {
242        &self.config.path
243    }
244
245    #[inline]
246    pub fn server(&self) -> WsIoServer {
247        WsIoServer(self.runtime.clone())
248    }
249
250    pub async fn shutdown(self: &Arc<Self>) {
251        match self.status.get() {
252            NamespaceStatus::Stopped => return,
253            NamespaceStatus::Running => self.status.store(NamespaceStatus::Stopping),
254            _ => unreachable!(),
255        }
256
257        let _ = WsIoServerNamespaceBroadcastOperator::new(self.clone())
258            .disconnect()
259            .await;
260
261        let mut connection_task_set = self.connection_task_set.lock().await;
262        while connection_task_set.join_next().await.is_some() {}
263
264        self.status.store(NamespaceStatus::Stopped);
265    }
266
267    #[inline]
268    pub fn to<I: IntoIterator<Item = S>, S: AsRef<str>>(
269        self: &Arc<Self>,
270        room_names: I,
271    ) -> WsIoServerNamespaceBroadcastOperator {
272        WsIoServerNamespaceBroadcastOperator::new(self.clone()).to(room_names)
273    }
274}