1use std::sync::Arc;
2
3use anyhow::Result;
4use futures_util::{
5 SinkExt,
6 StreamExt,
7};
8use http::{
9 HeaderMap,
10 Uri,
11};
12use hyper::upgrade::{
13 OnUpgrade,
14 Upgraded,
15};
16use hyper_util::rt::TokioIo;
17use kikiutils::{
18 atomic::enum_cell::AtomicEnumCell,
19 types::fx_collections::{
20 FxDashMap,
21 FxDashSet,
22 },
23};
24use num_enum::{
25 IntoPrimitive,
26 TryFromPrimitive,
27};
28use serde::Serialize;
29use tokio::{
30 join,
31 select,
32 spawn,
33 sync::Mutex,
34 task::JoinSet,
35};
36use tokio_tungstenite::{
37 WebSocketStream,
38 tungstenite::{
39 Message,
40 protocol::Role,
41 },
42};
43
44pub(crate) mod builder;
45mod config;
46pub mod operators;
47
48use self::{
49 config::WsIoServerNamespaceConfig,
50 operators::broadcast::WsIoServerNamespaceBroadcastOperator,
51};
52use crate::{
53 WsIoServer,
54 connection::WsIoServerConnection,
55 core::packet::WsIoPacket,
56 runtime::{
57 WsIoServerRuntime,
58 WsIoServerRuntimeStatus,
59 },
60};
61
62#[repr(u8)]
64#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
65enum NamespaceStatus {
66 Running,
67 Stopped,
68 Stopping,
69}
70
/// A WebSocket namespace: a set of live connections sharing one configuration,
/// plus the rooms those connections have joined.
pub struct WsIoServerNamespace {
    /// Namespace configuration (path, packet codec, WebSocket settings).
    pub(crate) config: WsIoServerNamespaceConfig,
    /// Live connections keyed by their connection id.
    connections: FxDashMap<u64, Arc<WsIoServerConnection>>,
    /// One task per upgrade request; joined during `shutdown`.
    connection_task_set: Mutex<JoinSet<()>>,
    /// Room name -> ids of the connections currently in that room.
    rooms: FxDashMap<String, Arc<FxDashSet<u64>>>,
    /// Shared server runtime; also tracks connection ids server-wide.
    runtime: Arc<WsIoServerRuntime>,
    /// Lifecycle state; only `Running` namespaces accept new connections.
    status: AtomicEnumCell<NamespaceStatus>,
}
80
81impl WsIoServerNamespace {
82 fn new(config: WsIoServerNamespaceConfig, runtime: Arc<WsIoServerRuntime>) -> Arc<Self> {
83 Arc::new(Self {
84 config,
85 connections: FxDashMap::default(),
86 connection_task_set: Mutex::new(JoinSet::new()),
87 rooms: FxDashMap::default(),
88 runtime,
89 status: AtomicEnumCell::new(NamespaceStatus::Running),
90 })
91 }
92
93 async fn handle_upgraded_request(
95 self: &Arc<Self>,
96 headers: HeaderMap,
97 request_uri: Uri,
98 upgraded: Upgraded,
99 ) -> Result<()> {
100 let mut ws_stream =
102 WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Server, Some(self.config.websocket_config))
103 .await;
104
105 if !self.runtime.status.is(WsIoServerRuntimeStatus::Running) || !self.status.is(NamespaceStatus::Running) {
107 ws_stream
108 .send((*self.encode_packet_to_message(&WsIoPacket::new_disconnect())?).clone())
109 .await?;
110
111 let _ = ws_stream.close(None).await;
112 return Ok(());
113 }
114
115 let (connection, mut message_rx) = WsIoServerConnection::new(headers, self.clone(), request_uri);
117
118 let (mut ws_stream_writer, mut ws_stream_reader) = ws_stream.split();
120 let connection_clone = connection.clone();
121 let mut read_ws_stream_task = spawn(async move {
122 while let Some(message) = ws_stream_reader.next().await {
123 if match message {
124 Ok(Message::Binary(bytes)) => {
125 if bytes.len() == 1 {
127 continue;
128 }
129
130 connection_clone.handle_incoming_packet(&bytes).await
131 }
132 Ok(Message::Close(_)) => break,
133 Ok(Message::Text(text)) => connection_clone.handle_incoming_packet(text.as_bytes()).await,
134 Err(_) => break,
135 _ => Ok(()),
136 }
137 .is_err()
138 {
139 break;
140 }
141 }
142 });
143
144 let mut write_ws_stream_task = spawn(async move {
145 while let Some(message) = message_rx.recv().await {
146 let message = (*message).clone();
147 let is_close = matches!(message, Message::Close(_));
148 if ws_stream_writer.send(message).await.is_err() {
149 break;
150 }
151
152 if is_close {
153 let _ = ws_stream_writer.close().await;
154 break;
155 }
156 }
157 });
158
159 match connection.init().await {
161 Ok(_) => {
162 select! {
164 _ = &mut read_ws_stream_task => {
165 write_ws_stream_task.abort();
166 },
167 _ = &mut write_ws_stream_task => {
168 read_ws_stream_task.abort();
169 },
170 }
171 }
172 Err(_) => {
173 read_ws_stream_task.abort();
175 connection.close();
176 let _ = join!(read_ws_stream_task, write_ws_stream_task);
177 }
178 }
179
180 connection.cleanup().await;
182 Ok(())
183 }
184
185 #[inline]
187 pub(crate) fn add_connection_id_to_room(&self, room_name: &str, connection_id: u64) {
188 self.rooms
189 .entry(room_name.into())
190 .or_default()
191 .clone()
192 .insert(connection_id);
193 }
194
195 #[inline]
196 pub(crate) fn encode_packet_to_message(&self, packet: &WsIoPacket) -> Result<Arc<Message>> {
197 let bytes = self.config.packet_codec.encode(packet)?;
198 Ok(Arc::new(match self.config.packet_codec.is_text() {
199 true => Message::Text(unsafe { String::from_utf8_unchecked(bytes).into() }),
200 false => Message::Binary(bytes.into()),
201 }))
202 }
203
204 pub(crate) async fn handle_on_upgrade_request(
205 self: &Arc<Self>,
206 headers: HeaderMap,
207 on_upgrade: OnUpgrade,
208 request_uri: Uri,
209 ) {
210 let namespace = self.clone();
211 self.connection_task_set.lock().await.spawn(async move {
212 if let Ok(upgraded) = on_upgrade.await {
213 let _ = namespace.handle_upgraded_request(headers, request_uri, upgraded).await;
214 }
215 });
216 }
217
218 #[inline]
219 pub(crate) fn insert_connection(&self, connection: Arc<WsIoServerConnection>) {
220 self.connections.insert(connection.id(), connection.clone());
221 self.runtime.insert_connection_id(connection.id());
222 }
223
224 #[inline]
225 pub(crate) fn remove_connection(&self, id: u64) {
226 self.connections.remove(&id);
227 self.runtime.remove_connection_id(id);
228 }
229
230 #[inline]
231 pub(crate) fn remove_connection_id_from_room(&self, room_name: &str, connection_id: u64) {
232 if let Some(room) = self.rooms.get(room_name).map(|entry| entry.clone()) {
233 room.remove(&connection_id);
234 if room.is_empty() {
235 self.rooms.remove(room_name);
236 }
237 }
238 }
239
240 pub async fn close_all(self: &Arc<Self>) {
242 WsIoServerNamespaceBroadcastOperator::new(self.clone()).close().await;
243 }
244
245 #[inline]
246 pub fn connection_count(&self) -> usize {
247 self.connections.len()
248 }
249
250 pub async fn disconnect_all(self: &Arc<Self>) -> Result<()> {
251 WsIoServerNamespaceBroadcastOperator::new(self.clone())
252 .disconnect()
253 .await
254 }
255
256 pub async fn emit<D: Serialize>(self: &Arc<Self>, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
257 WsIoServerNamespaceBroadcastOperator::new(self.clone())
258 .emit(event, data)
259 .await
260 }
261
262 #[inline]
263 pub fn except(
264 self: &Arc<Self>,
265 room_names: impl IntoIterator<Item = impl Into<String>>,
266 ) -> WsIoServerNamespaceBroadcastOperator {
267 WsIoServerNamespaceBroadcastOperator::new(self.clone()).except(room_names)
268 }
269
270 #[inline]
271 pub fn path(&self) -> &str {
272 &self.config.path
273 }
274
275 #[inline]
276 pub fn server(&self) -> WsIoServer {
277 WsIoServer(self.runtime.clone())
278 }
279
280 pub async fn shutdown(self: &Arc<Self>) {
281 match self.status.get() {
282 NamespaceStatus::Stopped => return,
283 NamespaceStatus::Running => self.status.store(NamespaceStatus::Stopping),
284 _ => unreachable!(),
285 }
286
287 self.close_all().await;
288 let mut connection_task_set = self.connection_task_set.lock().await;
289 while connection_task_set.join_next().await.is_some() {}
290
291 self.status.store(NamespaceStatus::Stopped);
292 }
293
294 #[inline]
295 pub fn to(
296 self: &Arc<Self>,
297 room_names: impl IntoIterator<Item = impl Into<String>>,
298 ) -> WsIoServerNamespaceBroadcastOperator {
299 WsIoServerNamespaceBroadcastOperator::new(self.clone()).to(room_names)
300 }
301}