1use std::sync::Arc;
2
3use anyhow::Result;
4use arc_swap::ArcSwap;
5use futures_util::{
6 SinkExt,
7 StreamExt,
8};
9use http::{
10 HeaderMap,
11 Uri,
12};
13use hyper::upgrade::{
14 OnUpgrade,
15 Upgraded,
16};
17use hyper_util::rt::TokioIo;
18use kikiutils::{
19 atomic::enum_cell::AtomicEnumCell,
20 types::fx_collections::FxDashMap,
21};
22use num_enum::{
23 IntoPrimitive,
24 TryFromPrimitive,
25};
26use roaring::RoaringTreemap;
27use serde::Serialize;
28use tokio::{
29 join,
30 select,
31 spawn,
32 sync::Mutex,
33 task::JoinSet,
34};
35use tokio_tungstenite::{
36 WebSocketStream,
37 tungstenite::{
38 Message,
39 protocol::Role,
40 },
41};
42
43pub(crate) mod builder;
44mod config;
45pub mod operators;
46
47use self::{
48 config::WsIoServerNamespaceConfig,
49 operators::broadcast::WsIoServerNamespaceBroadcastOperator,
50};
51use crate::{
52 WsIoServer,
53 connection::WsIoServerConnection,
54 core::packet::WsIoPacket,
55 runtime::{
56 WsIoServerRuntime,
57 WsIoServerRuntimeStatus,
58 },
59};
60
61#[repr(u8)]
63#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
64enum NamespaceStatus {
65 Running,
66 Stopped,
67 Stopping,
68}
69
/// A WebSocket namespace: its configuration, live connections, rooms and the tasks serving
/// each upgraded connection.
pub struct WsIoServerNamespace {
    /// Namespace configuration (path, packet codec, WebSocket protocol settings).
    pub(crate) config: WsIoServerNamespaceConfig,
    /// IDs of the connections in this namespace. Updated copy-on-write through
    /// `ArcSwap::rcu` whenever a connection is inserted or removed.
    connection_ids: ArcSwap<RoaringTreemap>,
    /// Live connections keyed by connection ID.
    connections: FxDashMap<u64, Arc<WsIoServerConnection>>,
    /// One task per upgrade request; all of them are joined during shutdown.
    connection_task_set: Mutex<JoinSet<()>>,
    /// Room name -> IDs of the connections in that room; a room is dropped once empty.
    rooms: FxDashMap<String, RoaringTreemap>,
    /// Server runtime this namespace belongs to.
    runtime: Arc<WsIoServerRuntime>,
    /// Current lifecycle state of the namespace.
    status: AtomicEnumCell<NamespaceStatus>,
}
80
81impl WsIoServerNamespace {
82 fn new(config: WsIoServerNamespaceConfig, runtime: Arc<WsIoServerRuntime>) -> Arc<Self> {
83 Arc::new(Self {
84 config,
85 connection_ids: ArcSwap::new(Arc::new(RoaringTreemap::new())),
86 connections: FxDashMap::default(),
87 connection_task_set: Mutex::new(JoinSet::new()),
88 rooms: FxDashMap::default(),
89 runtime,
90 status: AtomicEnumCell::new(NamespaceStatus::Running),
91 })
92 }
93
94 async fn handle_upgraded_request(
96 self: &Arc<Self>,
97 headers: HeaderMap,
98 request_uri: Uri,
99 upgraded: Upgraded,
100 ) -> Result<()> {
101 let mut ws_stream =
103 WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Server, Some(self.config.websocket_config))
104 .await;
105
106 if !self.runtime.status.is(WsIoServerRuntimeStatus::Running) || !self.status.is(NamespaceStatus::Running) {
108 ws_stream
109 .send((*self.encode_packet_to_message(&WsIoPacket::new_disconnect())?).clone())
110 .await?;
111
112 let _ = ws_stream.close(None).await;
113 return Ok(());
114 }
115
116 let (connection, mut message_rx) = WsIoServerConnection::new(headers, self.clone(), request_uri);
118
119 let (mut ws_stream_writer, mut ws_stream_reader) = ws_stream.split();
121 let connection_clone = connection.clone();
122 let mut read_ws_stream_task = spawn(async move {
123 while let Some(message) = ws_stream_reader.next().await {
124 if match message {
125 Ok(Message::Binary(bytes)) => {
126 if bytes.len() == 1 {
128 continue;
129 }
130
131 connection_clone.handle_incoming_packet(&bytes).await
132 }
133 Ok(Message::Close(_)) => break,
134 Ok(Message::Text(text)) => connection_clone.handle_incoming_packet(text.as_bytes()).await,
135 Err(_) => break,
136 _ => Ok(()),
137 }
138 .is_err()
139 {
140 break;
141 }
142 }
143 });
144
145 let mut write_ws_stream_task = spawn(async move {
146 while let Some(message) = message_rx.recv().await {
147 let message = (*message).clone();
148 let is_close = matches!(message, Message::Close(_));
149 if ws_stream_writer.send(message).await.is_err() {
150 break;
151 }
152
153 if is_close {
154 let _ = ws_stream_writer.close().await;
155 break;
156 }
157 }
158 });
159
160 match connection.init().await {
162 Ok(_) => {
163 select! {
165 _ = &mut read_ws_stream_task => {
166 write_ws_stream_task.abort();
167 },
168 _ = &mut write_ws_stream_task => {
169 read_ws_stream_task.abort();
170 },
171 }
172 }
173 Err(_) => {
174 read_ws_stream_task.abort();
176 connection.close();
177 let _ = join!(read_ws_stream_task, write_ws_stream_task);
178 }
179 }
180
181 connection.cleanup().await;
183 Ok(())
184 }
185
186 #[inline]
188 pub(crate) fn add_connection_id_to_room(&self, room_name: &str, connection_id: u64) {
189 self.rooms.entry(room_name.into()).or_default().insert(connection_id);
190 }
191
192 #[inline]
193 pub(crate) fn encode_packet_to_message(&self, packet: &WsIoPacket) -> Result<Arc<Message>> {
194 let bytes = self.config.packet_codec.encode(packet)?;
195 Ok(Arc::new(match self.config.packet_codec.is_text() {
196 true => Message::Text(unsafe { String::from_utf8_unchecked(bytes).into() }),
197 false => Message::Binary(bytes.into()),
198 }))
199 }
200
201 pub(crate) async fn handle_on_upgrade_request(
202 self: &Arc<Self>,
203 headers: HeaderMap,
204 on_upgrade: OnUpgrade,
205 request_uri: Uri,
206 ) {
207 let namespace = self.clone();
208 self.connection_task_set.lock().await.spawn(async move {
209 if let Ok(upgraded) = on_upgrade.await {
210 let _ = namespace.handle_upgraded_request(headers, request_uri, upgraded).await;
211 }
212 });
213 }
214
215 #[inline]
216 pub(crate) fn insert_connection(&self, connection: Arc<WsIoServerConnection>) {
217 self.connections.insert(connection.id(), connection.clone());
218 self.runtime.insert_connection_id(connection.id());
219 self.connection_ids.rcu(|old_connection_ids| {
220 let mut new_connection_ids = (**old_connection_ids).clone();
221 new_connection_ids.insert(connection.id());
222 new_connection_ids
223 });
224 }
225
226 #[inline]
227 pub(crate) fn remove_connection(&self, id: u64) {
228 self.connections.remove(&id);
229 self.runtime.remove_connection_id(id);
230 self.connection_ids.rcu(|old_connection_ids| {
231 let mut new_connection_ids = (**old_connection_ids).clone();
232 new_connection_ids.remove(id);
233 new_connection_ids
234 });
235 }
236
237 #[inline]
238 pub(crate) fn remove_connection_id_from_room(&self, room_name: &str, connection_id: u64) {
239 if let Some(mut entry) = self.rooms.get_mut(room_name) {
240 entry.remove(connection_id);
241 if entry.is_empty() {
242 drop(entry);
243 self.rooms.remove(room_name);
244 }
245 }
246 }
247
248 pub async fn close_all(self: &Arc<Self>) {
250 WsIoServerNamespaceBroadcastOperator::new(self.clone()).close().await;
251 }
252
253 #[inline]
254 pub fn connection_count(&self) -> usize {
255 self.connections.len()
256 }
257
258 pub async fn disconnect_all(self: &Arc<Self>) -> Result<()> {
259 WsIoServerNamespaceBroadcastOperator::new(self.clone())
260 .disconnect()
261 .await
262 }
263
264 pub async fn emit<D: Serialize>(self: &Arc<Self>, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
265 WsIoServerNamespaceBroadcastOperator::new(self.clone())
266 .emit(event, data)
267 .await
268 }
269
270 #[inline]
271 pub fn except(
272 self: &Arc<Self>,
273 room_names: impl IntoIterator<Item = impl Into<String>>,
274 ) -> WsIoServerNamespaceBroadcastOperator {
275 WsIoServerNamespaceBroadcastOperator::new(self.clone()).except(room_names)
276 }
277
278 #[inline]
279 pub fn path(&self) -> &str {
280 &self.config.path
281 }
282
283 #[inline]
284 pub fn server(&self) -> WsIoServer {
285 WsIoServer(self.runtime.clone())
286 }
287
288 pub async fn shutdown(self: &Arc<Self>) {
289 match self.status.get() {
290 NamespaceStatus::Stopped => return,
291 NamespaceStatus::Running => self.status.store(NamespaceStatus::Stopping),
292 _ => unreachable!(),
293 }
294
295 self.close_all().await;
296 let mut connection_task_set = self.connection_task_set.lock().await;
297 while connection_task_set.join_next().await.is_some() {}
298
299 self.status.store(NamespaceStatus::Stopped);
300 }
301
302 #[inline]
303 pub fn to(
304 self: &Arc<Self>,
305 room_names: impl IntoIterator<Item = impl Into<String>>,
306 ) -> WsIoServerNamespaceBroadcastOperator {
307 WsIoServerNamespaceBroadcastOperator::new(self.clone()).to(room_names)
308 }
309}