1use std::sync::Arc;
2
3use anyhow::Result;
4use arc_swap::ArcSwap;
5use futures_util::{
6 SinkExt,
7 StreamExt,
8};
9use http::{
10 HeaderMap,
11 Uri,
12};
13use hyper::upgrade::{
14 OnUpgrade,
15 Upgraded,
16};
17use hyper_util::rt::TokioIo;
18use kikiutils::{
19 atomic::enum_cell::AtomicEnumCell,
20 types::fx_collections::FxDashMap,
21};
22use num_enum::{
23 IntoPrimitive,
24 TryFromPrimitive,
25};
26use roaring::RoaringTreemap;
27use serde::Serialize;
28use tokio::{
29 join,
30 select,
31 spawn,
32 sync::Mutex,
33 task::JoinSet,
34 time::timeout,
35};
36use tokio_tungstenite::{
37 WebSocketStream,
38 tungstenite::{
39 Message,
40 protocol::Role,
41 },
42};
43
44pub(crate) mod builder;
45mod config;
46pub mod operators;
47
48use self::{
49 config::WsIoServerNamespaceConfig,
50 operators::broadcast::WsIoServerNamespaceBroadcastOperator,
51};
52use crate::{
53 WsIoServer,
54 connection::WsIoServerConnection,
55 core::packet::WsIoPacket,
56 runtime::{
57 WsIoServerRuntime,
58 WsIoServerRuntimeStatus,
59 },
60};
61
62#[repr(u8)]
64#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
65enum NamespaceStatus {
66 Running,
67 Stopped,
68 Stopping,
69}
70
/// A WebSocket namespace mounted at a single path.
///
/// It owns the connections accepted on that path, the rooms they have joined, and the
/// tasks spawned to serve HTTP upgrade requests.
pub struct WsIoServerNamespace {
    /// Per-namespace settings (path, packet codec, WebSocket config, upgrade timeout, ...).
    pub(crate) config: WsIoServerNamespaceConfig,
    /// Snapshot of the ids of live connections.
    ///
    /// It is replaced wholesale through `ArcSwap::rcu` whenever a connection is inserted or
    /// removed.
    connection_ids: ArcSwap<RoaringTreemap>,
    /// Live connections keyed by connection id.
    connections: FxDashMap<u64, Arc<WsIoServerConnection>>,
    /// Tasks spawned for upgrade requests; drained during `shutdown`.
    connection_task_set: Mutex<JoinSet<()>>,
    /// Room name -> ids of the connections in that room; empty rooms are removed.
    rooms: FxDashMap<String, RoaringTreemap>,
    /// The server runtime this namespace belongs to.
    ///
    /// It is also told about connection ids as they are inserted and removed.
    runtime: Arc<WsIoServerRuntime>,
    /// Current lifecycle state (see `NamespaceStatus`).
    status: AtomicEnumCell<NamespaceStatus>,
}
81
82impl WsIoServerNamespace {
83 fn new(config: WsIoServerNamespaceConfig, runtime: Arc<WsIoServerRuntime>) -> Arc<Self> {
84 Arc::new(Self {
85 config,
86 connection_ids: ArcSwap::new(Arc::new(RoaringTreemap::new())),
87 connections: FxDashMap::default(),
88 connection_task_set: Mutex::new(JoinSet::new()),
89 rooms: FxDashMap::default(),
90 runtime,
91 status: AtomicEnumCell::new(NamespaceStatus::Running),
92 })
93 }
94
95 async fn handle_upgraded_request(
97 self: &Arc<Self>,
98 headers: HeaderMap,
99 request_uri: Uri,
100 upgraded: Upgraded,
101 ) -> Result<()> {
102 let mut ws_stream =
104 WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Server, Some(self.config.websocket_config))
105 .await;
106
107 if !self.runtime.status.is(WsIoServerRuntimeStatus::Running) || !self.status.is(NamespaceStatus::Running) {
109 ws_stream
110 .send((*self.encode_packet_to_message(&WsIoPacket::new_disconnect())?).clone())
111 .await?;
112
113 let _ = ws_stream.close(None).await;
114 return Ok(());
115 }
116
117 let (connection, mut message_rx) = WsIoServerConnection::new(headers, self.clone(), request_uri);
119
120 let (mut ws_stream_writer, mut ws_stream_reader) = ws_stream.split();
122 let connection_clone = connection.clone();
123 let mut read_ws_stream_task = spawn(async move {
124 while let Some(message) = ws_stream_reader.next().await {
125 if match message {
126 Ok(Message::Binary(bytes)) => {
127 if bytes.len() == 1 {
129 continue;
130 }
131
132 connection_clone.handle_incoming_packet(&bytes).await
133 }
134 Ok(Message::Close(_)) => break,
135 Ok(Message::Text(text)) => connection_clone.handle_incoming_packet(text.as_bytes()).await,
136 Err(_) => break,
137 _ => Ok(()),
138 }
139 .is_err()
140 {
141 break;
142 }
143 }
144 });
145
146 let mut write_ws_stream_task = spawn(async move {
147 while let Some(message) = message_rx.recv().await {
148 let message = (*message).clone();
149 let is_close = matches!(message, Message::Close(_));
150 if ws_stream_writer.send(message).await.is_err() {
151 break;
152 }
153
154 if is_close {
155 let _ = ws_stream_writer.close().await;
156 break;
157 }
158 }
159 });
160
161 match connection.init().await {
163 Ok(_) => {
164 select! {
166 _ = &mut read_ws_stream_task => {
167 write_ws_stream_task.abort();
168 },
169 _ = &mut write_ws_stream_task => {
170 read_ws_stream_task.abort();
171 },
172 }
173 }
174 Err(_) => {
175 read_ws_stream_task.abort();
177 connection.close();
178 let _ = join!(read_ws_stream_task, write_ws_stream_task);
179 }
180 }
181
182 connection.cleanup().await;
184 Ok(())
185 }
186
187 #[inline]
189 pub(crate) fn add_connection_id_to_room(&self, room_name: &str, connection_id: u64) {
190 self.rooms.entry(room_name.into()).or_default().insert(connection_id);
191 }
192
193 #[inline]
194 pub(crate) fn encode_packet_to_message(&self, packet: &WsIoPacket) -> Result<Arc<Message>> {
195 let bytes = self.config.packet_codec.encode(packet)?;
196 Ok(Arc::new(match self.config.packet_codec.is_text() {
197 true => Message::Text(unsafe { String::from_utf8_unchecked(bytes).into() }),
198 false => Message::Binary(bytes.into()),
199 }))
200 }
201
202 pub(crate) async fn handle_on_upgrade_request(
203 self: &Arc<Self>,
204 headers: HeaderMap,
205 on_upgrade: OnUpgrade,
206 request_uri: Uri,
207 ) {
208 let namespace = self.clone();
209 self.connection_task_set.lock().await.spawn(async move {
210 if let Ok(Ok(upgraded)) = timeout(namespace.config.http_request_upgrade_timeout, on_upgrade).await {
211 let _ = namespace.handle_upgraded_request(headers, request_uri, upgraded).await;
212 }
213 });
214 }
215
216 #[inline]
217 pub(crate) fn insert_connection(&self, connection: Arc<WsIoServerConnection>) {
218 self.connections.insert(connection.id(), connection.clone());
219 self.runtime.insert_connection_id(connection.id());
220 self.connection_ids.rcu(|old_connection_ids| {
221 let mut new_connection_ids = (**old_connection_ids).clone();
222 new_connection_ids.insert(connection.id());
223 new_connection_ids
224 });
225 }
226
227 #[inline]
228 pub(crate) fn remove_connection(&self, id: u64) {
229 self.connections.remove(&id);
230 self.runtime.remove_connection_id(id);
231 self.connection_ids.rcu(|old_connection_ids| {
232 let mut new_connection_ids = (**old_connection_ids).clone();
233 new_connection_ids.remove(id);
234 new_connection_ids
235 });
236 }
237
238 #[inline]
239 pub(crate) fn remove_connection_id_from_room(&self, room_name: &str, connection_id: u64) {
240 if let Some(mut entry) = self.rooms.get_mut(room_name) {
241 entry.remove(connection_id);
242 }
243
244 self.rooms.remove_if(room_name, |_, entry| entry.is_empty());
245 }
246
247 pub async fn close_all(self: &Arc<Self>) {
249 WsIoServerNamespaceBroadcastOperator::new(self.clone()).close().await;
250 }
251
252 #[inline]
253 pub fn connection_count(&self) -> usize {
254 self.connections.len()
255 }
256
257 pub async fn disconnect_all(self: &Arc<Self>) -> Result<()> {
258 WsIoServerNamespaceBroadcastOperator::new(self.clone())
259 .disconnect()
260 .await
261 }
262
263 pub async fn emit<D: Serialize>(self: &Arc<Self>, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
264 WsIoServerNamespaceBroadcastOperator::new(self.clone())
265 .emit(event, data)
266 .await
267 }
268
269 #[inline]
270 pub fn except(
271 self: &Arc<Self>,
272 room_names: impl IntoIterator<Item = impl Into<String>>,
273 ) -> WsIoServerNamespaceBroadcastOperator {
274 WsIoServerNamespaceBroadcastOperator::new(self.clone()).except(room_names)
275 }
276
277 #[inline]
278 pub fn path(&self) -> &str {
279 &self.config.path
280 }
281
282 #[inline]
283 pub fn server(&self) -> WsIoServer {
284 WsIoServer(self.runtime.clone())
285 }
286
287 pub async fn shutdown(self: &Arc<Self>) {
288 match self.status.get() {
289 NamespaceStatus::Stopped => return,
290 NamespaceStatus::Running => self.status.store(NamespaceStatus::Stopping),
291 _ => unreachable!(),
292 }
293
294 self.close_all().await;
295 let mut connection_task_set = self.connection_task_set.lock().await;
296 while connection_task_set.join_next().await.is_some() {}
297
298 self.status.store(NamespaceStatus::Stopped);
299 }
300
301 #[inline]
302 pub fn to(
303 self: &Arc<Self>,
304 room_names: impl IntoIterator<Item = impl Into<String>>,
305 ) -> WsIoServerNamespaceBroadcastOperator {
306 WsIoServerNamespaceBroadcastOperator::new(self.clone()).to(room_names)
307 }
308}
309
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;

    use super::*;
    use crate::{
        config::WsIoServerConfig,
        core::packet::codecs::WsIoPacketCodec,
    };

    /// Builds a namespace registered at `/test` on a fresh runtime with short timeouts.
    fn create_test_namespace() -> Arc<WsIoServerNamespace> {
        let runtime = WsIoServerRuntime::new(WsIoServerConfig {
            broadcast_concurrency_limit: 16,
            http_request_upgrade_timeout: Duration::from_secs(3),
            init_request_handler_timeout: Duration::from_secs(3),
            init_response_handler_timeout: Duration::from_secs(3),
            init_response_timeout: Duration::from_secs(3),
            middleware_execution_timeout: Duration::from_secs(3),
            on_close_handler_timeout: Duration::from_secs(3),
            on_connect_handler_timeout: Duration::from_secs(3),
            packet_codec: WsIoPacketCodec::SerdeJson,
            request_path: "/socket".into(),
            websocket_config: WebSocketConfig::default(),
        });
        runtime.new_namespace_builder("/test").register().unwrap()
    }

    #[tokio::test]
    async fn test_namespace_new() {
        let namespace = create_test_namespace();
        assert_eq!(namespace.path(), "/test");
        assert_eq!(namespace.connection_count(), 0);
    }

    #[tokio::test]
    async fn test_namespace_connection_count() {
        let namespace = create_test_namespace();
        assert_eq!(namespace.connection_count(), 0);
    }

    #[tokio::test]
    async fn test_namespace_server() {
        let namespace = create_test_namespace();
        namespace.server();
    }

    #[tokio::test]
    async fn test_namespace_to_broadcast_operator() {
        let namespace = create_test_namespace();
        namespace.to(["room1", "room2"]);
    }

    #[tokio::test]
    async fn test_namespace_except_broadcast_operator() {
        let namespace = create_test_namespace();
        namespace.except(["room1", "room2"]);
    }

    #[tokio::test]
    async fn test_namespace_add_remove_connection_id_to_room() {
        let namespace = create_test_namespace();
        namespace.add_connection_id_to_room("room1", 1);
        namespace.add_connection_id_to_room("room1", 2);
        namespace.add_connection_id_to_room("room2", 3);
        assert_eq!(namespace.rooms.len(), 2);

        // Removing one member keeps a room that still has members.
        namespace.remove_connection_id_from_room("room1", 1);
        {
            // Scope the read guard so it is released before the next removal locks the shard.
            let room1 = namespace.rooms.get("room1").expect("room1 still has a member");
            assert!(!room1.contains(1));
            assert!(room1.contains(2));
        }

        // Removing the last member of a room drops the room entirely.
        namespace.remove_connection_id_from_room("room1", 2);
        namespace.remove_connection_id_from_room("room2", 3);
        assert!(namespace.rooms.is_empty());
    }

    #[tokio::test]
    async fn test_namespace_remove_connection_id_from_empty_room() {
        let namespace = create_test_namespace();
        namespace.remove_connection_id_from_room("nonexistent", 1);
        assert!(namespace.rooms.is_empty());
    }

    #[tokio::test]
    async fn test_namespace_encode_packet_to_message() {
        let namespace = create_test_namespace();
        let packet = WsIoPacket::new_disconnect();
        namespace.encode_packet_to_message(&packet).unwrap();
    }

    #[tokio::test]
    async fn test_namespace_shutdown_idempotent() {
        let namespace = create_test_namespace();
        namespace.shutdown().await;
        namespace.shutdown().await;
        assert!(namespace.status.is(NamespaceStatus::Stopped));
    }

    #[tokio::test]
    async fn test_broadcast_operator_new() {
        let namespace = create_test_namespace();
        namespace.to(["room1", "room2"]);
    }

    #[tokio::test]
    async fn test_broadcast_operator_to_chaining() {
        let namespace = create_test_namespace();
        namespace.to(["room1"]).to(["room2"]);
    }

    #[tokio::test]
    async fn test_broadcast_operator_except_chaining() {
        let namespace = create_test_namespace();
        namespace.except(["room1"]).except(["room2"]);
    }

    #[tokio::test]
    async fn test_broadcast_operator_except_connection_ids() {
        let namespace = create_test_namespace();
        namespace
            .except([1.to_string()])
            .except_connection_ids([1, 2, 3]);
    }

    #[tokio::test]
    async fn test_broadcast_operator_to_with_empty_rooms() {
        let namespace = create_test_namespace();
        namespace.to(Vec::<String>::new());
    }

    #[tokio::test]
    async fn test_broadcast_operator_combined() {
        let namespace = create_test_namespace();
        namespace
            .to(["room1", "room2"])
            .except(["room3"])
            .except_connection_ids([100]);
    }

    #[tokio::test]
    async fn test_broadcast_operator_disconnect_with_no_connections() {
        let namespace = create_test_namespace();
        let op = namespace.to(["room1"]);
        let result = op.clone().disconnect().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_broadcast_operator_emit_requires_running() {
        let namespace = create_test_namespace();
        namespace.shutdown().await;

        let op = namespace.to(["room1"]);
        let result = op.emit("event", Option::<&()>::None).await;
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("invalid status"));
    }

    #[tokio::test]
    async fn test_broadcast_operator_close_is_noop_when_empty() {
        let namespace = create_test_namespace();
        let op = namespace.to(["room1"]);
        op.clone().close().await;
    }
}