1use std::sync::Arc;
2
3use anyhow::{
4 Result,
5 bail,
6};
7use http::HeaderMap;
8use num_enum::{
9 IntoPrimitive,
10 TryFromPrimitive,
11};
12use serde::Serialize;
13use tokio::{
14 select,
15 spawn,
16 sync::{
17 Mutex,
18 mpsc::{
19 Receiver,
20 Sender,
21 channel,
22 },
23 },
24 task::JoinHandle,
25 time::{
26 sleep,
27 timeout,
28 },
29};
30use tokio_tungstenite::tungstenite::Message;
31use tokio_util::sync::CancellationToken;
32
33#[cfg(feature = "connection-extensions")]
34mod extensions;
35
36#[cfg(feature = "connection-extensions")]
37use self::extensions::WsIoServerConnectionExtensions;
38use crate::{
39 WsIoServer,
40 core::{
41 atomic::status::AtomicStatus,
42 packet::{
43 WsIoPacket,
44 WsIoPacketType,
45 },
46 types::BoxAsyncUnaryResultHandler,
47 utils::task::abort_locked_task,
48 },
49 namespace::WsIoServerNamespace,
50};
51
52#[repr(u8)]
53#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
54enum ConnectionStatus {
55 Activating,
56 Authenticating,
57 AwaitingAuth,
58 Closed,
59 Closing,
60 Created,
61 Ready,
62}
63
/// A single client connection attached to a [`WsIoServerNamespace`].
///
/// Its lifecycle is tracked in `status`, outbound frames are queued on
/// `message_tx`, and `cleanup` tears it down.
pub struct WsIoServerConnection {
    /// Task spawned by `init` that closes the connection if no auth packet arrives
    /// in time; aborted when an auth packet is handled and during cleanup.
    auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
    /// Cancelled during cleanup; tasks started via `spawn_task` stop when it fires.
    cancel_token: CancellationToken,
    /// Per-connection extensions (only with the `connection-extensions` feature).
    #[cfg(feature = "connection-extensions")]
    extensions: WsIoServerConnectionExtensions,
    /// Headers supplied at construction (presumably the upgrade request headers —
    /// confirm at the call site).
    headers: HeaderMap,
    /// Sending half of the outbound message channel; the receiver is returned by `new`.
    message_tx: Sender<Message>,
    /// Owning namespace; its config supplies handlers, codec and timeouts.
    namespace: Arc<WsIoServerNamespace>,
    /// Handler registered via `on_close`; taken and run during cleanup.
    on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
    /// Session id; passed to the namespace's `remove_connection` during cleanup.
    sid: String,
    /// Current lifecycle state.
    status: AtomicStatus<ConnectionStatus>,
}
76
77impl WsIoServerConnection {
78 pub(crate) fn new(
79 headers: HeaderMap,
80 namespace: Arc<WsIoServerNamespace>,
81 sid: String,
82 ) -> (Arc<Self>, Receiver<Message>) {
83 let channel_capacity = (namespace.config.websocket_config.max_write_buffer_size
84 / namespace.config.websocket_config.write_buffer_size)
85 .clamp(64, 4096);
86
87 let (message_tx, message_rx) = channel(channel_capacity);
88 (
89 Arc::new(Self {
90 auth_timeout_task: Mutex::new(None),
91 cancel_token: CancellationToken::new(),
92 #[cfg(feature = "connection-extensions")]
93 extensions: WsIoServerConnectionExtensions::new(),
94 headers,
95 message_tx,
96 namespace,
97 on_close_handler: Mutex::new(None),
98 sid,
99 status: AtomicStatus::new(ConnectionStatus::Created),
100 }),
101 message_rx,
102 )
103 }
104
105 async fn activate(self: &Arc<Self>) -> Result<()> {
107 let status = self.status.get();
109 match status {
110 ConnectionStatus::Authenticating | ConnectionStatus::Created => {
111 self.status.try_transition(status, ConnectionStatus::Activating)?
112 }
113 _ => bail!("Cannot activate connection in invalid status: {:#?}", status),
114 }
115
116 if let Some(middleware) = &self.namespace.config.middleware {
118 timeout(
119 self.namespace.config.middleware_execution_timeout,
120 middleware(self.clone()),
121 )
122 .await??;
123 }
124
125 if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
127 timeout(
128 self.namespace.config.on_connect_handler_timeout,
129 on_connect_handler(self.clone()),
130 )
131 .await??;
132 }
133
134 self.namespace.insert_connection(self.clone());
136
137 self.status
139 .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
140
141 self.send_packet(&WsIoPacket {
143 data: None,
144 key: None,
145 r#type: WsIoPacketType::Ready,
146 })
147 .await?;
148
149 if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
151 let connection = self.clone();
153 self.spawn_task(async move { on_ready_handler(connection).await });
154 }
155
156 Ok(())
157 }
158
159 async fn handle_auth_packet(self: &Arc<Self>, packet_data: &[u8]) -> Result<()> {
160 let status = self.status.get();
162 match status {
163 ConnectionStatus::AwaitingAuth => self.status.try_transition(status, ConnectionStatus::Authenticating)?,
164 _ => bail!("Received auth packet in invalid status: {:#?}", status),
165 }
166
167 abort_locked_task(&self.auth_timeout_task).await;
169
170 if let Some(auth_handler) = &self.namespace.config.auth_handler {
172 timeout(
173 self.namespace.config.auth_handler_timeout,
174 auth_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
175 )
176 .await??;
177
178 self.activate().await
180 } else {
181 bail!("Auth packet received but no auth handler is configured");
182 }
183 }
184
185 async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
186 Ok(self
187 .message_tx
188 .send(self.namespace.encode_packet_to_message(packet)?)
189 .await?)
190 }
191
192 pub(crate) async fn cleanup(self: &Arc<Self>) {
194 self.status.store(ConnectionStatus::Closing);
196
197 abort_locked_task(&self.auth_timeout_task).await;
199
200 self.namespace.remove_connection(&self.sid);
202
203 self.cancel_token.cancel();
205
206 if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
208 let _ = timeout(
209 self.namespace.config.on_close_handler_timeout,
210 on_close_handler(self.clone()),
211 )
212 .await;
213 }
214
215 self.status.store(ConnectionStatus::Closed);
217 }
218
219 #[inline]
220 pub(crate) fn close(&self) {
221 match self.status.get() {
223 ConnectionStatus::Closed | ConnectionStatus::Closing => return,
224 _ => self.status.store(ConnectionStatus::Closing),
225 }
226
227 let _ = self.message_tx.try_send(Message::Close(None));
229 }
230
231 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
232 let packet = self.namespace.config.packet_codec.decode(bytes)?;
233 match packet.r#type {
234 WsIoPacketType::Auth => {
235 if let Some(packet_data) = packet.data.as_deref() {
236 self.handle_auth_packet(packet_data).await
237 } else {
238 bail!("Auth packet missing data");
239 }
240 }
241 _ => Ok(()),
242 }
243 }
244
245 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
246 let status = self.status.get();
248 if !matches!(status, ConnectionStatus::Created) {
249 bail!("Cannot init connection in invalid status: {:#?}", status);
250 }
251
252 let requires_auth = self.namespace.config.auth_handler.is_some();
254
255 let packet = WsIoPacket {
257 data: Some(self.namespace.config.packet_codec.encode_data(&requires_auth)?),
258 key: None,
259 r#type: WsIoPacketType::Init,
260 };
261
262 if requires_auth {
264 self.status
266 .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingAuth)?;
267
268 let connection = self.clone();
270 *self.auth_timeout_task.lock().await = Some(spawn(async move {
271 sleep(connection.namespace.config.auth_packet_timeout).await;
272 if connection.status.is(ConnectionStatus::AwaitingAuth) {
273 connection.close();
274 }
275 }));
276
277 self.send_packet(&packet).await
279 } else {
280 self.send_packet(&packet).await?;
282
283 self.activate().await
285 }
286 }
287
288 #[inline]
291 pub fn cancel_token(&self) -> &CancellationToken {
292 &self.cancel_token
293 }
294
295 pub async fn disconnect(&self) {
296 let _ = self
297 .send_packet(&WsIoPacket {
298 data: None,
299 key: None,
300 r#type: WsIoPacketType::Disconnect,
301 })
302 .await;
303
304 self.close()
305 }
306
307 pub async fn emit<D: Serialize>(&self, event: impl Into<String>, data: Option<&D>) -> Result<()> {
308 let status = self.status.get();
309 if status != ConnectionStatus::Ready {
310 bail!("Cannot emit event in invalid status: {:#?}", status);
311 }
312
313 self.send_packet(&WsIoPacket {
314 data: data
315 .map(|data| self.namespace.config.packet_codec.encode_data(data))
316 .transpose()?,
317 key: Some(event.into()),
318 r#type: WsIoPacketType::Event,
319 })
320 .await
321 }
322
323 #[cfg(feature = "connection-extensions")]
324 #[inline]
325 pub fn extensions(&self) -> &WsIoServerConnectionExtensions {
326 &self.extensions
327 }
328
329 #[inline]
330 pub fn headers(&self) -> &HeaderMap {
331 &self.headers
332 }
333
334 #[inline]
335 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
336 self.namespace.clone()
337 }
338
339 pub async fn on_close<H, Fut>(&self, handler: H)
340 where
341 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
342 Fut: Future<Output = Result<()>> + Send + 'static,
343 {
344 *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
345 }
346
347 #[inline]
348 pub fn server(&self) -> WsIoServer {
349 self.namespace.server()
350 }
351
352 #[inline]
353 pub fn sid(&self) -> &str {
354 &self.sid
355 }
356
357 #[inline]
358 pub fn spawn_task<F: Future<Output = Result<()>> + Send + 'static>(&self, future: F) {
359 let cancel_token = self.cancel_token.clone();
360 spawn(async move {
361 select! {
362 _ = cancel_token.cancelled() => {},
363 _ = future => {},
364 }
365 });
366 }
367}