1use std::sync::{
2 Arc,
3 LazyLock,
4 atomic::{
5 AtomicU64,
6 Ordering,
7 },
8};
9
10use anyhow::{
11 Result,
12 bail,
13};
14use arc_swap::ArcSwap;
15use http::{
16 HeaderMap,
17 Uri,
18};
19use kikiutils::{
20 atomic::enum_cell::AtomicEnumCell,
21 types::fx_collections::FxDashSet,
22};
23use num_enum::{
24 IntoPrimitive,
25 TryFromPrimitive,
26};
27use serde::{
28 Serialize,
29 de::DeserializeOwned,
30};
31use tokio::{
32 spawn,
33 sync::{
34 Mutex,
35 mpsc::{
36 Receiver,
37 Sender,
38 channel,
39 },
40 },
41 task::JoinHandle,
42 time::{
43 sleep,
44 timeout,
45 },
46};
47use tokio_tungstenite::tungstenite::Message;
48use tokio_util::sync::CancellationToken;
49
50#[cfg(feature = "connection-extensions")]
51mod extensions;
52
53#[cfg(feature = "connection-extensions")]
54use self::extensions::ConnectionExtensions;
55use crate::{
56 WsIoServer,
57 core::{
58 channel_capacity_from_websocket_config,
59 event::registry::WsIoEventRegistry,
60 packet::{
61 WsIoPacket,
62 WsIoPacketType,
63 },
64 traits::task::spawner::TaskSpawner,
65 types::BoxAsyncUnaryResultHandler,
66 utils::task::abort_locked_task,
67 },
68 namespace::{
69 WsIoServerNamespace,
70 operators::broadcast::WsIoServerNamespaceBroadcastOperator,
71 },
72};
73
74#[repr(u8)]
76#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
77enum ConnectionState {
78 Activating,
79 AwaitingInit,
80 Closed,
81 Closing,
82 Created,
83 Initiating,
84 Ready,
85}
86
/// A single client connection inside a [`WsIoServerNamespace`].
///
/// Handed around as `Arc<WsIoServerConnection>`. Outbound frames go through a bounded channel
/// whose receiving half is returned by `new`. Handshake and shutdown progress is tracked in
/// `state`.
pub struct WsIoServerConnection {
    /// Token returned by `TaskSpawner::cancel_token`; cancelled in `cleanup`.
    /// NOTE(review): held in an `ArcSwap`, presumably so it can be replaced — nothing visible here swaps it.
    cancel_token: ArcSwap<CancellationToken>,
    /// Per-connection event handlers registered via `on`/`off` and used for incoming event packets.
    event_registry: WsIoEventRegistry<WsIoServerConnection, WsIoServerConnection>,
    /// Extension data exposed through `extensions()` (only with the `connection-extensions` feature).
    #[cfg(feature = "connection-extensions")]
    extensions: ConnectionExtensions,
    /// Headers passed to `new` (presumably those of the WebSocket upgrade request).
    headers: HeaderMap,
    /// Id taken from `NEXT_CONNECTION_ID` at construction.
    id: u64,
    /// Timer task spawned by `init` that closes the connection if the client never answers;
    /// aborted by `handle_init_packet` and `cleanup`.
    init_timeout_task: Mutex<Option<JoinHandle<()>>>,
    /// Names of rooms joined via `join`, so `cleanup` can remove the connection from them.
    joined_rooms: FxDashSet<String>,
    /// Sending half of the outbound message channel created in `new`.
    message_tx: Sender<Arc<Message>>,
    /// Namespace this connection belongs to (config, codec, rooms, connection registry).
    namespace: Arc<WsIoServerNamespace>,
    /// Handler registered via `on_close`; taken and run once by `cleanup`.
    on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
    /// Request URI passed to `new`.
    request_uri: Uri,
    /// Current lifecycle state.
    state: AtomicEnumCell<ConnectionState>,
}
103
104impl TaskSpawner for WsIoServerConnection {
105 #[inline]
106 fn cancel_token(&self) -> Arc<CancellationToken> {
107 self.cancel_token.load_full()
108 }
109}
110
111impl WsIoServerConnection {
112 #[inline]
113 pub(crate) fn new(
114 headers: HeaderMap,
115 namespace: Arc<WsIoServerNamespace>,
116 request_uri: Uri,
117 ) -> (Arc<Self>, Receiver<Arc<Message>>) {
118 let channel_capacity = channel_capacity_from_websocket_config(&namespace.config.websocket_config);
119 let (message_tx, message_rx) = channel(channel_capacity);
120 (
121 Arc::new(Self {
122 cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
123 event_registry: WsIoEventRegistry::new(),
124 #[cfg(feature = "connection-extensions")]
125 extensions: ConnectionExtensions::new(),
126 headers,
127 id: NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed),
128 init_timeout_task: Mutex::new(None),
129 joined_rooms: FxDashSet::default(),
130 message_tx,
131 namespace,
132 on_close_handler: Mutex::new(None),
133 request_uri,
134 state: AtomicEnumCell::new(ConnectionState::Created),
135 }),
136 message_rx,
137 )
138 }
139
140 #[inline]
142 fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
143 self.event_registry.dispatch_event_packet(
144 self.clone(),
145 event,
146 &self.namespace.config.packet_codec,
147 packet_data,
148 self,
149 );
150
151 Ok(())
152 }
153
154 async fn handle_init_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
155 let state = self.state.get();
157 match state {
158 ConnectionState::AwaitingInit => self.state.try_transition(state, ConnectionState::Initiating)?,
159 _ => bail!("Received init packet in invalid state: {state:?}"),
160 }
161
162 abort_locked_task(&self.init_timeout_task).await;
164
165 if let Some(init_response_handler) = &self.namespace.config.init_response_handler {
167 timeout(
168 self.namespace.config.init_response_handler_timeout,
169 init_response_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
170 )
171 .await??
172 }
173
174 self.state
176 .try_transition(ConnectionState::Initiating, ConnectionState::Activating)?;
177
178 if let Some(middleware) = &self.namespace.config.middleware {
180 timeout(
181 self.namespace.config.middleware_execution_timeout,
182 middleware(self.clone()),
183 )
184 .await??;
185
186 self.state.ensure(ConnectionState::Activating, |state| {
188 format!("Cannot activate connection in invalid state: {state:?}")
189 })?;
190 }
191
192 if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
194 timeout(
195 self.namespace.config.on_connect_handler_timeout,
196 on_connect_handler(self.clone()),
197 )
198 .await??;
199 }
200
201 self.state
203 .try_transition(ConnectionState::Activating, ConnectionState::Ready)?;
204
205 self.namespace.insert_connection(self.clone());
207
208 self.send_packet(&WsIoPacket::new_ready()).await?;
210
211 if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
213 self.spawn_task(on_ready_handler(self.clone()));
215 }
216
217 Ok(())
218 }
219
220 async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
221 self.send_message(self.namespace.encode_packet_to_message(packet)?)
222 .await
223 }
224
225 pub(crate) async fn cleanup(self: &Arc<Self>) {
227 self.state.store(ConnectionState::Closing);
229
230 self.namespace.remove_connection(self.id);
232
233 let joined_rooms = self.joined_rooms.iter().map(|entry| entry.clone()).collect::<Vec<_>>();
235 for room_name in &joined_rooms {
236 self.namespace.remove_connection_id_from_room(room_name, self.id);
237 }
238
239 self.joined_rooms.clear();
240
241 abort_locked_task(&self.init_timeout_task).await;
243
244 self.cancel_token.load().cancel();
246
247 if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
249 let _ = timeout(
250 self.namespace.config.on_close_handler_timeout,
251 on_close_handler(self.clone()),
252 )
253 .await;
254 }
255
256 self.state.store(ConnectionState::Closed);
258 }
259
260 #[inline]
261 pub(crate) fn close(&self) {
262 match self.state.get() {
264 ConnectionState::Closed | ConnectionState::Closing => return,
265 _ => self.state.store(ConnectionState::Closing),
266 }
267
268 let _ = self.message_tx.try_send(Arc::new(Message::Close(None)));
270 }
271
272 pub(crate) async fn emit_event_message(&self, message: Arc<Message>) -> Result<()> {
273 self.state.ensure(ConnectionState::Ready, |state| {
274 format!("Cannot emit in invalid state: {state:?}")
275 })?;
276
277 self.send_message(message).await
278 }
279
280 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, encoded_packet: &[u8]) -> Result<()> {
281 let packet = self.namespace.config.packet_codec.decode(encoded_packet)?;
283 match packet.r#type {
284 WsIoPacketType::Event => {
285 if self.is_ready() {
286 if let Some(event) = packet.key.as_deref() {
287 return self.handle_event_packet(event, packet.data);
288 } else {
289 bail!("Event packet missing key");
290 }
291 }
292
293 Ok(())
294 }
295 WsIoPacketType::Init => self.handle_init_packet(packet.data.as_deref()).await,
296 _ => Ok(()),
297 }
298 }
299
300 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
301 self.state.ensure(ConnectionState::Created, |state| {
303 format!("Cannot init connection in invalid state: {state:?}")
304 })?;
305
306 let init_request_data = if let Some(init_request_handler) = &self.namespace.config.init_request_handler {
308 timeout(
309 self.namespace.config.init_request_handler_timeout,
310 init_request_handler(self.clone(), &self.namespace.config.packet_codec),
311 )
312 .await??
313 } else {
314 None
315 };
316
317 self.state
319 .try_transition(ConnectionState::Created, ConnectionState::AwaitingInit)?;
320
321 let connection = self.clone();
323 *self.init_timeout_task.lock().await = Some(spawn(async move {
324 sleep(connection.namespace.config.init_response_timeout).await;
325 if connection.state.is(ConnectionState::AwaitingInit) {
326 connection.close();
327 }
328 }));
329
330 self.send_packet(&WsIoPacket::new_init(init_request_data)).await
332 }
333
334 pub(crate) async fn send_message(&self, message: Arc<Message>) -> Result<()> {
335 Ok(self.message_tx.send(message).await?)
336 }
337
338 pub async fn disconnect(&self) {
340 let _ = self.send_packet(&WsIoPacket::new_disconnect()).await;
341 self.close()
342 }
343
344 pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
345 self.emit_event_message(
346 self.namespace.encode_packet_to_message(&WsIoPacket::new_event(
347 event.as_ref(),
348 data.map(|data| self.namespace.config.packet_codec.encode_data(data))
349 .transpose()?,
350 ))?,
351 )
352 .await
353 }
354
355 #[inline]
356 pub fn except(
357 self: &Arc<Self>,
358 room_names: impl IntoIterator<Item = impl Into<String>>,
359 ) -> WsIoServerNamespaceBroadcastOperator {
360 self.namespace.except(room_names).except_connection_ids([self.id])
361 }
362
363 #[cfg(feature = "connection-extensions")]
364 #[inline]
365 pub fn extensions(&self) -> &ConnectionExtensions {
366 &self.extensions
367 }
368
369 #[inline]
370 pub fn headers(&self) -> &HeaderMap {
371 &self.headers
372 }
373
374 #[inline]
375 pub fn id(&self) -> u64 {
376 self.id
377 }
378
379 #[inline]
380 pub fn is_ready(&self) -> bool {
381 self.state.is(ConnectionState::Ready)
382 }
383
384 #[inline]
385 pub fn join(self: &Arc<Self>, room_names: impl IntoIterator<Item = impl Into<String>>) {
386 for room_name in room_names {
387 let room_name = room_name.into();
388 self.namespace.add_connection_id_to_room(&room_name, self.id);
389 self.joined_rooms.insert(room_name);
390 }
391 }
392
393 #[inline]
394 pub fn leave(self: &Arc<Self>, room_names: impl IntoIterator<Item = impl Into<String>>) {
395 for room_name in room_names {
396 let room_name = &room_name.into();
397 self.namespace.remove_connection_id_from_room(room_name, self.id);
398
399 self.joined_rooms.remove(room_name);
400 }
401 }
402
403 #[inline]
404 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
405 self.namespace.clone()
406 }
407
408 #[inline]
409 pub fn off(&self, event: impl AsRef<str>) {
410 self.event_registry.off(event.as_ref());
411 }
412
413 #[inline]
414 pub fn off_by_handler_id(&self, event: impl AsRef<str>, handler_id: u32) {
415 self.event_registry.off_by_handler_id(event.as_ref(), handler_id);
416 }
417
418 #[inline]
419 pub fn on<H, Fut, D>(&self, event: impl AsRef<str>, handler: H) -> u32
420 where
421 H: Fn(Arc<WsIoServerConnection>, Arc<D>) -> Fut + Send + Sync + 'static,
422 Fut: Future<Output = Result<()>> + Send + 'static,
423 D: DeserializeOwned + Send + Sync + 'static,
424 {
425 self.event_registry.on(event.as_ref(), handler)
426 }
427
428 pub async fn on_close<H, Fut>(&self, handler: H)
429 where
430 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
431 Fut: Future<Output = Result<()>> + Send + 'static,
432 {
433 *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
434 }
435
436 #[inline]
437 pub fn request_uri(&self) -> &Uri {
438 &self.request_uri
439 }
440
441 #[inline]
442 pub fn server(&self) -> WsIoServer {
443 self.namespace.server()
444 }
445
446 #[inline]
447 pub fn to(
448 self: &Arc<Self>,
449 room_names: impl IntoIterator<Item = impl Into<String>>,
450 ) -> WsIoServerNamespaceBroadcastOperator {
451 self.namespace.to(room_names).except_connection_ids([self.id])
452 }
453}
454
/// Source of connection ids; `fetch_add` hands out 0, 1, 2, … in creation order.
static NEXT_CONNECTION_ID: LazyLock<AtomicU64> = LazyLock::new(AtomicU64::default);
457
#[cfg(test)]
mod tests {
    // `HeaderMap`, `Uri`, `WsIoServer` and the connection internals all come from the parent module.
    use super::*;

    /// Builds an unregistered connection in a fresh `/socket` namespace.
    ///
    /// The outbound receiver is dropped on return, so sends from the connection fail; the
    /// close/disconnect paths ignore those failures.
    async fn create_test_connection() -> Arc<WsIoServerConnection> {
        let server = Arc::new(WsIoServer::builder().build());
        let namespace = server.new_namespace_builder("/socket").register().unwrap();
        let (connection, _rx) =
            WsIoServerConnection::new(HeaderMap::new(), namespace, Uri::from_static("http://localhost"));

        connection
    }

    #[tokio::test]
    async fn test_handle_incoming_packet_decode_error() {
        let connection = create_test_connection().await;
        let garbage_data = b"obviously not valid json or messagepack";
        let result = connection.handle_incoming_packet(garbage_data).await;
        assert!(result.is_err(), "Decoding garbage payload should trigger an error");
    }

    #[tokio::test]
    async fn test_handle_init_packet_in_invalid_state() {
        let connection = create_test_connection().await;
        assert_eq!(connection.state.get(), ConnectionState::Created);

        // Assumes the default codec encodes packets as `[type, key, data]` with Init = 2.
        let encoded = b"[2,null,null]";

        let result = connection.handle_incoming_packet(encoded).await;
        assert!(
            result.is_err(),
            "Should error because state is Created, not AwaitingInit"
        );

        assert!(result.unwrap_err().to_string().contains("invalid state"));
    }

    #[tokio::test]
    async fn test_handle_event_packet_missing_key() {
        let connection = create_test_connection().await;

        connection.state.store(ConnectionState::Ready);

        // Assumes the default codec encodes packets as `[type, key, data]` with Event = 1.
        let encoded = b"[1,null,null]";

        let result = connection.handle_incoming_packet(encoded).await;
        assert!(result.is_err(), "Should bail on missing event key");
        assert_eq!(result.unwrap_err().to_string(), "Event packet missing key");
    }

    #[tokio::test]
    async fn test_connection_close_state_transitions() {
        let connection = create_test_connection().await;
        assert_eq!(connection.state.get(), ConnectionState::Created);

        connection.close();
        assert_eq!(connection.state.get(), ConnectionState::Closing);

        // A second close is a no-op.
        connection.close();
        assert_eq!(connection.state.get(), ConnectionState::Closing);
    }

    #[tokio::test]
    async fn test_close_after_cleanup_keeps_closed_state() {
        let connection = create_test_connection().await;
        connection.namespace().insert_connection(connection.clone());

        connection.cleanup().await;
        assert_eq!(connection.state.get(), ConnectionState::Closed);

        // Closing an already torn-down connection must not move it back to `Closing`.
        connection.close();
        assert_eq!(connection.state.get(), ConnectionState::Closed);
    }

    #[tokio::test]
    async fn test_join_and_leave_track_rooms() {
        let connection = create_test_connection().await;
        connection.namespace().insert_connection(connection.clone());

        connection.join(["room_a", "room_b"]);
        connection.leave(["room_a"]);

        assert!(!connection.joined_rooms.contains("room_a"));
        assert!(connection.joined_rooms.contains("room_b"));
    }

    #[tokio::test]
    async fn test_connection_cleanup() {
        let connection = create_test_connection().await;
        let namespace = connection.namespace();

        namespace.insert_connection(connection.clone());
        assert_eq!(namespace.connection_count(), 1);

        connection.join(["room_a", "room_b"]);
        assert!(connection.joined_rooms.contains("room_a"));

        connection.cleanup().await;

        assert_eq!(connection.state.get(), ConnectionState::Closed);
        assert!(connection.joined_rooms.is_empty());
        assert_eq!(namespace.connection_count(), 0);
    }
}
550}