1use std::sync::Arc;
2
3use anyhow::{
4 Result,
5 bail,
6};
7use http::HeaderMap;
8use num_enum::{
9 IntoPrimitive,
10 TryFromPrimitive,
11};
12use serde::{
13 Serialize,
14 de::DeserializeOwned,
15};
16use tokio::{
17 select,
18 spawn,
19 sync::{
20 Mutex,
21 mpsc::{
22 Receiver,
23 Sender,
24 channel,
25 },
26 },
27 task::JoinHandle,
28 time::{
29 sleep,
30 timeout,
31 },
32};
33use tokio_tungstenite::tungstenite::Message;
34use tokio_util::sync::CancellationToken;
35
36#[cfg(feature = "connection-extensions")]
37mod extensions;
38
39#[cfg(feature = "connection-extensions")]
40use self::extensions::ConnectionExtensions;
41use crate::{
42 WsIoServer,
43 core::{
44 atomic::status::AtomicStatus,
45 channel_capacity_from_websocket_config,
46 event::registry::WsIoEventRegistry,
47 packet::{
48 WsIoPacket,
49 WsIoPacketType,
50 },
51 types::BoxAsyncUnaryResultHandler,
52 utils::task::abort_locked_task,
53 },
54 namespace::WsIoServerNamespace,
55};
56
57#[repr(u8)]
58#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
59enum ConnectionStatus {
60 Activating,
61 Authenticating,
62 AwaitingAuth,
63 Closed,
64 Closing,
65 Created,
66 Ready,
67}
68
/// Server-side state for one client connection in a `WsIoServerNamespace`.
///
/// Outbound traffic is queued on `message_tx`; the matching receiver is handed out by
/// `WsIoServerConnection::new`.
pub struct WsIoServerConnection {
    /// Handle of the watchdog task spawned by `init` when authentication is required. It
    /// closes the connection if it is still awaiting auth after `auth_packet_timeout`.
    auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
    /// Cancelled during `cleanup`; tasks started through `spawn_task` stop when it fires.
    cancel_token: CancellationToken,
    /// Event handlers registered with `on`; incoming event packets are dispatched here.
    event_registry: WsIoEventRegistry<WsIoServerConnection>,
    /// Per-connection extension storage (only with the `connection-extensions` feature).
    #[cfg(feature = "connection-extensions")]
    extensions: ConnectionExtensions,
    /// Headers supplied to `new` (presumably from the upgrade request — verify at call site).
    headers: HeaderMap,
    /// Sending half of the outbound message channel.
    message_tx: Sender<Message>,
    /// Namespace that owns this connection; supplies configuration and the packet codec.
    namespace: Arc<WsIoServerNamespace>,
    /// Handler set via `on_close`; taken and run at most once by `cleanup`.
    on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
    /// Connection identifier; also the key passed to `remove_connection` during cleanup.
    sid: String,
    /// Current lifecycle status.
    status: AtomicStatus<ConnectionStatus>,
}
82
83impl WsIoServerConnection {
84 #[inline]
85 pub(crate) fn new(
86 headers: HeaderMap,
87 namespace: Arc<WsIoServerNamespace>,
88 sid: String,
89 ) -> (Arc<Self>, Receiver<Message>) {
90 let channel_capacity = channel_capacity_from_websocket_config(&namespace.config.websocket_config);
91 let (message_tx, message_rx) = channel(channel_capacity);
92 (
93 Arc::new(Self {
94 auth_timeout_task: Mutex::new(None),
95 cancel_token: CancellationToken::new(),
96 event_registry: WsIoEventRegistry::new(namespace.config.packet_codec),
97 #[cfg(feature = "connection-extensions")]
98 extensions: ConnectionExtensions::new(),
99 headers,
100 message_tx,
101 namespace,
102 on_close_handler: Mutex::new(None),
103 sid,
104 status: AtomicStatus::new(ConnectionStatus::Created),
105 }),
106 message_rx,
107 )
108 }
109
110 async fn activate(self: &Arc<Self>) -> Result<()> {
112 let status = self.status.get();
114 match status {
115 ConnectionStatus::Authenticating | ConnectionStatus::Created => {
116 self.status.try_transition(status, ConnectionStatus::Activating)?
117 }
118 _ => bail!("Cannot activate connection in invalid status: {:#?}", status),
119 }
120
121 if let Some(middleware) = &self.namespace.config.middleware {
123 timeout(
124 self.namespace.config.middleware_execution_timeout,
125 middleware(self.clone()),
126 )
127 .await??;
128 }
129
130 if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
132 timeout(
133 self.namespace.config.on_connect_handler_timeout,
134 on_connect_handler(self.clone()),
135 )
136 .await??;
137 }
138
139 self.namespace.insert_connection(self.clone());
141
142 self.status
144 .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
145
146 self.send_packet(&WsIoPacket::new_ready()).await?;
148
149 if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
151 let connection = self.clone();
153 self.spawn_task(async move { on_ready_handler(connection).await });
154 }
155
156 Ok(())
157 }
158
159 #[inline]
160 fn ensure_status_ready(&self) -> Result<()> {
161 let status = self.status.get();
162 if status != ConnectionStatus::Ready {
163 bail!("Cannot emit event in invalid status: {:#?}", status);
164 }
165
166 Ok(())
167 }
168
169 async fn handle_auth_packet(self: &Arc<Self>, packet_data: &[u8]) -> Result<()> {
170 let status = self.status.get();
172 match status {
173 ConnectionStatus::AwaitingAuth => self.status.try_transition(status, ConnectionStatus::Authenticating)?,
174 _ => bail!("Received auth packet in invalid status: {:#?}", status),
175 }
176
177 abort_locked_task(&self.auth_timeout_task).await;
179
180 if let Some(auth_handler) = &self.namespace.config.auth_handler {
182 timeout(
183 self.namespace.config.auth_handler_timeout,
184 auth_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
185 )
186 .await??;
187
188 self.activate().await
190 } else {
191 bail!("Auth packet received but no auth handler is configured");
192 }
193 }
194
195 #[inline]
196 fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
197 self.event_registry
198 .dispatch_event_packet(self.clone(), event, packet_data);
199
200 Ok(())
201 }
202
203 async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
204 Ok(self
205 .message_tx
206 .send(self.namespace.encode_packet_to_message(packet)?)
207 .await?)
208 }
209
210 pub(crate) async fn cleanup(self: &Arc<Self>) {
212 self.status.store(ConnectionStatus::Closing);
214
215 abort_locked_task(&self.auth_timeout_task).await;
217
218 self.namespace.remove_connection(&self.sid);
220
221 self.cancel_token.cancel();
223
224 if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
226 let _ = timeout(
227 self.namespace.config.on_close_handler_timeout,
228 on_close_handler(self.clone()),
229 )
230 .await;
231 }
232
233 self.status.store(ConnectionStatus::Closed);
235 }
236
237 #[inline]
238 pub(crate) fn close(&self) {
239 match self.status.get() {
241 ConnectionStatus::Closed | ConnectionStatus::Closing => return,
242 _ => self.status.store(ConnectionStatus::Closing),
243 }
244
245 let _ = self.message_tx.try_send(Message::Close(None));
247 }
248
249 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
250 let packet = self.namespace.config.packet_codec.decode(bytes)?;
252 match packet.r#type {
253 WsIoPacketType::Auth => {
254 if let Some(packet_data) = packet.data.as_deref() {
255 self.handle_auth_packet(packet_data).await
256 } else {
257 bail!("Auth packet missing data");
258 }
259 }
260 WsIoPacketType::Event => {
261 if let Some(event) = packet.key.as_deref() {
262 self.handle_event_packet(event, packet.data)
263 } else {
264 bail!("Event packet missing key");
265 }
266 }
267 _ => Ok(()),
268 }
269 }
270
271 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
272 let status = self.status.get();
274 if !matches!(status, ConnectionStatus::Created) {
275 bail!("Cannot init connection in invalid status: {:#?}", status);
276 }
277
278 let requires_auth = self.namespace.config.auth_handler.is_some();
280
281 let packet = &WsIoPacket::new_init(self.namespace.config.packet_codec.encode_data(&requires_auth)?);
283
284 if requires_auth {
286 self.status
288 .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingAuth)?;
289
290 let connection = self.clone();
292 *self.auth_timeout_task.lock().await = Some(spawn(async move {
293 sleep(connection.namespace.config.auth_packet_timeout).await;
294 if connection.status.is(ConnectionStatus::AwaitingAuth) {
295 connection.close();
296 }
297 }));
298
299 self.send_packet(packet).await
301 } else {
302 self.send_packet(packet).await?;
304
305 self.activate().await
307 }
308 }
309
310 #[inline]
313 pub fn cancel_token(&self) -> &CancellationToken {
314 &self.cancel_token
315 }
316
317 pub async fn disconnect(&self) {
318 let _ = self.send_packet(&WsIoPacket::new_disconnect()).await;
319 self.close()
320 }
321
322 pub async fn emit<D: Serialize>(&self, event: impl Into<String>, data: Option<&D>) -> Result<()> {
323 self.ensure_status_ready()?;
324 self.send_packet(&WsIoPacket::new_event(
325 event,
326 data.map(|data| self.namespace.config.packet_codec.encode_data(data))
327 .transpose()?,
328 ))
329 .await
330 }
331
332 pub async fn emit_message(&self, message: Message) -> Result<()> {
333 self.ensure_status_ready()?;
334 Ok(self.message_tx.send(message).await?)
335 }
336
337 #[cfg(feature = "connection-extensions")]
338 #[inline]
339 pub fn extensions(&self) -> &ConnectionExtensions {
340 &self.extensions
341 }
342
343 #[inline]
344 pub fn headers(&self) -> &HeaderMap {
345 &self.headers
346 }
347
348 #[inline]
349 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
350 self.namespace.clone()
351 }
352
353 #[inline]
354 pub fn off(&self, event: impl AsRef<str>) {
355 self.event_registry.off(event);
356 }
357
358 #[inline]
359 pub fn off_by_handler_id(&self, event: impl AsRef<str>, handler_id: u32) {
360 self.event_registry.off_by_handler_id(event, handler_id);
361 }
362
363 #[inline]
364 pub fn on<H, Fut, D>(&self, event: impl Into<String>, handler: H) -> u32
365 where
366 H: Fn(Arc<WsIoServerConnection>, Arc<D>) -> Fut + Send + Sync + 'static,
367 Fut: Future<Output = Result<()>> + Send + 'static,
368 D: DeserializeOwned + Send + Sync + 'static,
369 {
370 self.event_registry.on(event, handler)
371 }
372
373 pub async fn on_close<H, Fut>(&self, handler: H)
374 where
375 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
376 Fut: Future<Output = Result<()>> + Send + 'static,
377 {
378 *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
379 }
380
381 #[inline]
382 pub fn server(&self) -> WsIoServer {
383 self.namespace.server()
384 }
385
386 #[inline]
387 pub fn sid(&self) -> &str {
388 &self.sid
389 }
390
391 #[inline]
392 pub fn spawn_task<F: Future<Output = Result<()>> + Send + 'static>(&self, future: F) {
393 let cancel_token = self.cancel_token.clone();
394 spawn(async move {
395 select! {
396 _ = cancel_token.cancelled() => {},
397 _ = future => {},
398 }
399 });
400 }
401}