1use std::sync::{
2 Arc,
3 LazyLock,
4 atomic::{
5 AtomicU64,
6 Ordering,
7 },
8};
9
10use anyhow::{
11 Result,
12 bail,
13};
14use arc_swap::ArcSwap;
15use http::{
16 HeaderMap,
17 Uri,
18};
19use num_enum::{
20 IntoPrimitive,
21 TryFromPrimitive,
22};
23use serde::{
24 Serialize,
25 de::DeserializeOwned,
26};
27use tokio::{
28 spawn,
29 sync::{
30 Mutex,
31 mpsc::{
32 Receiver,
33 Sender,
34 channel,
35 },
36 },
37 task::JoinHandle,
38 time::{
39 sleep,
40 timeout,
41 },
42};
43use tokio_tungstenite::tungstenite::Message;
44use tokio_util::sync::CancellationToken;
45
46#[cfg(feature = "connection-extensions")]
47mod extensions;
48
49#[cfg(feature = "connection-extensions")]
50use self::extensions::ConnectionExtensions;
51use crate::{
52 WsIoServer,
53 core::{
54 atomic::status::AtomicStatus,
55 channel_capacity_from_websocket_config,
56 event::registry::WsIoEventRegistry,
57 packet::{
58 WsIoPacket,
59 WsIoPacketType,
60 },
61 traits::task::spawner::TaskSpawner,
62 types::{
63 BoxAsyncUnaryResultHandler,
64 hashers::FxDashSet,
65 },
66 utils::task::abort_locked_task,
67 },
68 namespace::{
69 WsIoServerNamespace,
70 operators::broadcast::WsIoServerNamespaceBroadcastOperator,
71 },
72};
73
74#[repr(u8)]
76#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
77enum ConnectionStatus {
78 Activating,
79 AwaitingInit,
80 Closed,
81 Closing,
82 Created,
83 Initiating,
84 Ready,
85}
86
87pub struct WsIoServerConnection {
89 cancel_token: ArcSwap<CancellationToken>,
90 event_registry: WsIoEventRegistry<WsIoServerConnection, WsIoServerConnection>,
91 #[cfg(feature = "connection-extensions")]
92 extensions: ConnectionExtensions,
93 headers: HeaderMap,
94 id: u64,
95 init_timeout_task: Mutex<Option<JoinHandle<()>>>,
96 joined_rooms: FxDashSet<String>,
97 message_tx: Sender<Arc<Message>>,
98 namespace: Arc<WsIoServerNamespace>,
99 on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
100 request_uri: Uri,
101 status: AtomicStatus<ConnectionStatus>,
102}
103
104impl TaskSpawner for WsIoServerConnection {
105 #[inline]
106 fn cancel_token(&self) -> Arc<CancellationToken> {
107 self.cancel_token.load_full()
108 }
109}
110
111impl WsIoServerConnection {
112 #[inline]
113 pub(crate) fn new(
114 headers: HeaderMap,
115 namespace: Arc<WsIoServerNamespace>,
116 request_uri: Uri,
117 ) -> (Arc<Self>, Receiver<Arc<Message>>) {
118 let channel_capacity = channel_capacity_from_websocket_config(&namespace.config.websocket_config);
119 let (message_tx, message_rx) = channel(channel_capacity);
120 (
121 Arc::new(Self {
122 cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
123 event_registry: WsIoEventRegistry::new(),
124 #[cfg(feature = "connection-extensions")]
125 extensions: ConnectionExtensions::new(),
126 headers,
127 id: NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed),
128 init_timeout_task: Mutex::new(None),
129 joined_rooms: FxDashSet::default(),
130 message_tx,
131 namespace,
132 on_close_handler: Mutex::new(None),
133 request_uri,
134 status: AtomicStatus::new(ConnectionStatus::Created),
135 }),
136 message_rx,
137 )
138 }
139
140 #[inline]
142 fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
143 if self.is_ready() {
144 self.event_registry.dispatch_event_packet(
145 self.clone(),
146 event,
147 &self.namespace.config.packet_codec,
148 packet_data,
149 self,
150 );
151 }
152
153 Ok(())
154 }
155
156 async fn handle_init_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
157 let status = self.status.get();
159 match status {
160 ConnectionStatus::AwaitingInit => self.status.try_transition(status, ConnectionStatus::Initiating)?,
161 _ => bail!("Received init packet in invalid status: {status:?}"),
162 }
163
164 abort_locked_task(&self.init_timeout_task).await;
166
167 if let Some(init_response_handler) = &self.namespace.config.init_response_handler {
169 timeout(
170 self.namespace.config.init_response_handler_timeout,
171 init_response_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
172 )
173 .await??
174 }
175
176 self.status
178 .try_transition(ConnectionStatus::Initiating, ConnectionStatus::Activating)?;
179
180 if let Some(middleware) = &self.namespace.config.middleware {
182 timeout(
183 self.namespace.config.middleware_execution_timeout,
184 middleware(self.clone()),
185 )
186 .await??;
187
188 self.status.ensure(ConnectionStatus::Activating, |status| {
190 format!("Cannot activate connection in invalid status: {status:?}")
191 })?;
192 }
193
194 if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
196 timeout(
197 self.namespace.config.on_connect_handler_timeout,
198 on_connect_handler(self.clone()),
199 )
200 .await??;
201 }
202
203 self.status
205 .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
206
207 self.namespace.insert_connection(self.clone());
209
210 self.send_packet(&WsIoPacket::new_ready()).await?;
212
213 if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
215 self.spawn_task(on_ready_handler(self.clone()));
217 }
218
219 Ok(())
220 }
221
222 async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
223 self.send_message(self.namespace.encode_packet_to_message(packet)?)
224 .await
225 }
226
227 pub(crate) async fn cleanup(self: &Arc<Self>) {
229 self.status.store(ConnectionStatus::Closing);
231
232 self.namespace.remove_connection(self.id);
234
235 let joined_rooms = self.joined_rooms.iter().map(|entry| entry.clone()).collect::<Vec<_>>();
237 for room_name in &joined_rooms {
238 self.namespace.remove_connection_id_from_room(room_name, self.id);
239 }
240
241 self.joined_rooms.clear();
242
243 abort_locked_task(&self.init_timeout_task).await;
245
246 self.cancel_token.load().cancel();
248
249 if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
251 let _ = timeout(
252 self.namespace.config.on_close_handler_timeout,
253 on_close_handler(self.clone()),
254 )
255 .await;
256 }
257
258 self.status.store(ConnectionStatus::Closed);
260 }
261
262 #[inline]
263 pub(crate) fn close(&self) {
264 match self.status.get() {
266 ConnectionStatus::Closed | ConnectionStatus::Closing => return,
267 _ => self.status.store(ConnectionStatus::Closing),
268 }
269
270 let _ = self.message_tx.try_send(Arc::new(Message::Close(None)));
272 }
273
274 pub(crate) async fn emit_event_message(&self, message: Arc<Message>) -> Result<()> {
275 self.status.ensure(ConnectionStatus::Ready, |status| {
276 format!("Cannot emit in invalid status: {status:?}")
277 })?;
278
279 self.send_message(message).await
280 }
281
282 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, encoded_packet: &[u8]) -> Result<()> {
283 let packet = self.namespace.config.packet_codec.decode(encoded_packet)?;
285 match packet.r#type {
286 WsIoPacketType::Event => {
287 if let Some(event) = packet.key.as_deref() {
288 self.handle_event_packet(event, packet.data)
289 } else {
290 bail!("Event packet missing key");
291 }
292 }
293 WsIoPacketType::Init => self.handle_init_packet(packet.data.as_deref()).await,
294 _ => Ok(()),
295 }
296 }
297
298 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
299 self.status.ensure(ConnectionStatus::Created, |status| {
301 format!("Cannot init connection in invalid status: {status:?}")
302 })?;
303
304 let init_request_data = if let Some(init_request_handler) = &self.namespace.config.init_request_handler {
306 timeout(
307 self.namespace.config.init_request_handler_timeout,
308 init_request_handler(self.clone(), &self.namespace.config.packet_codec),
309 )
310 .await??
311 } else {
312 None
313 };
314
315 self.status
317 .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingInit)?;
318
319 let connection = self.clone();
321 *self.init_timeout_task.lock().await = Some(spawn(async move {
322 sleep(connection.namespace.config.init_response_timeout).await;
323 if connection.status.is(ConnectionStatus::AwaitingInit) {
324 connection.close();
325 }
326 }));
327
328 self.send_packet(&WsIoPacket::new_init(init_request_data)).await
330 }
331
332 pub(crate) async fn send_message(&self, message: Arc<Message>) -> Result<()> {
333 Ok(self.message_tx.send(message).await?)
334 }
335
336 pub async fn disconnect(&self) {
338 let _ = self.send_packet(&WsIoPacket::new_disconnect()).await;
339 self.close()
340 }
341
342 pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
343 self.emit_event_message(
344 self.namespace.encode_packet_to_message(&WsIoPacket::new_event(
345 event.as_ref(),
346 data.map(|data| self.namespace.config.packet_codec.encode_data(data))
347 .transpose()?,
348 ))?,
349 )
350 .await
351 }
352
353 #[inline]
354 pub fn except(
355 self: &Arc<Self>,
356 room_names: impl IntoIterator<Item = impl AsRef<str>>,
357 ) -> WsIoServerNamespaceBroadcastOperator {
358 self.namespace.except(room_names).except_connection_ids(vec![self.id])
359 }
360
361 #[cfg(feature = "connection-extensions")]
362 #[inline]
363 pub fn extensions(&self) -> &ConnectionExtensions {
364 &self.extensions
365 }
366
367 #[inline]
368 pub fn headers(&self) -> &HeaderMap {
369 &self.headers
370 }
371
372 #[inline]
373 pub fn id(&self) -> u64 {
374 self.id
375 }
376
377 #[inline]
378 pub fn is_ready(&self) -> bool {
379 self.status.is(ConnectionStatus::Ready)
380 }
381
382 #[inline]
383 pub fn join(self: &Arc<Self>, room_names: impl IntoIterator<Item = impl AsRef<str>>) {
384 for room_name in room_names {
385 let room_name = room_name.as_ref();
386 self.namespace.add_connection_id_to_room(room_name, self.id);
387 self.joined_rooms.insert(room_name.into());
388 }
389 }
390
391 #[inline]
392 pub fn leave(self: &Arc<Self>, room_names: impl IntoIterator<Item = impl AsRef<str>>) {
393 for room_name in room_names {
394 self.namespace
395 .remove_connection_id_from_room(room_name.as_ref(), self.id);
396
397 self.joined_rooms.remove(room_name.as_ref());
398 }
399 }
400
401 #[inline]
402 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
403 self.namespace.clone()
404 }
405
406 #[inline]
407 pub fn off(&self, event: impl AsRef<str>) {
408 self.event_registry.off(event.as_ref());
409 }
410
411 #[inline]
412 pub fn off_by_handler_id(&self, event: impl AsRef<str>, handler_id: u32) {
413 self.event_registry.off_by_handler_id(event.as_ref(), handler_id);
414 }
415
416 #[inline]
417 pub fn on<H, Fut, D>(&self, event: impl AsRef<str>, handler: H) -> u32
418 where
419 H: Fn(Arc<WsIoServerConnection>, Arc<D>) -> Fut + Send + Sync + 'static,
420 Fut: Future<Output = Result<()>> + Send + 'static,
421 D: DeserializeOwned + Send + Sync + 'static,
422 {
423 self.event_registry.on(event.as_ref(), handler)
424 }
425
426 pub async fn on_close<H, Fut>(&self, handler: H)
427 where
428 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
429 Fut: Future<Output = Result<()>> + Send + 'static,
430 {
431 *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
432 }
433
434 #[inline]
435 pub fn request_uri(&self) -> &Uri {
436 &self.request_uri
437 }
438
439 #[inline]
440 pub fn server(&self) -> WsIoServer {
441 self.namespace.server()
442 }
443
444 #[inline]
445 pub fn to(
446 self: &Arc<Self>,
447 room_names: impl IntoIterator<Item = impl AsRef<str>>,
448 ) -> WsIoServerNamespaceBroadcastOperator {
449 self.namespace.to(room_names).except_connection_ids(vec![self.id])
450 }
451}
452
/// Process-wide counter that hands out connection ids, starting at 0.
///
/// `AtomicU64::new` is a `const fn`, so the counter is initialized at compile
/// time and needs no lazy wrapper. Call sites such as
/// `NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed)` are unaffected.
// NOTE(review): the `LazyLock` import at the top of the file is no longer used
// by this item and can be dropped if nothing else in the file needs it.
static NEXT_CONNECTION_ID: AtomicU64 = AtomicU64::new(0);