1use std::{
2 pin::Pin,
3 sync::Arc,
4};
5
6use anyhow::{
7 Result,
8 bail,
9};
10use http::HeaderMap;
11use num_enum::{
12 IntoPrimitive,
13 TryFromPrimitive,
14};
15use serde::Serialize;
16use tokio::{
17 select,
18 spawn,
19 sync::{
20 Mutex,
21 mpsc::{
22 Receiver,
23 Sender,
24 channel,
25 },
26 },
27 task::JoinHandle,
28 time::sleep,
29};
30use tokio_tungstenite::tungstenite::Message;
31use tokio_util::sync::CancellationToken;
32
33#[cfg(feature = "connection-extensions")]
34mod extensions;
35
36#[cfg(feature = "connection-extensions")]
37use self::extensions::WsIoServerConnectionExtensions;
38use crate::{
39 WsIoServer,
40 core::{
41 atomic::status::AtomicStatus,
42 packet::{
43 WsIoPacket,
44 WsIoPacketType,
45 },
46 },
47 namespace::WsIoServerNamespace,
48};
49
50#[repr(u8)]
51#[derive(Debug, IntoPrimitive, TryFromPrimitive)]
52enum ConnectionStatus {
53 Activating,
54 Authenticating,
55 AwaitingAuth,
56 Closed,
57 Closing,
58 Created,
59 Ready,
60}
61
62type OnCloseHandler = Box<
63 dyn Fn(Arc<WsIoServerConnection>) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>
64 + Send
65 + Sync
66 + 'static,
67>;
68
69pub struct WsIoServerConnection {
70 auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
71 cancel_token: CancellationToken,
72 #[cfg(feature = "connection-extensions")]
73 extensions: WsIoServerConnectionExtensions,
74 headers: HeaderMap,
75 message_tx: Sender<Message>,
76 namespace: Arc<WsIoServerNamespace>,
77 on_close_handler: Mutex<Option<OnCloseHandler>>,
78 sid: String,
79 status: AtomicStatus<ConnectionStatus>,
80}
81
82impl WsIoServerConnection {
83 pub(crate) fn new(
84 headers: HeaderMap,
85 namespace: Arc<WsIoServerNamespace>,
86 sid: String,
87 ) -> (Arc<Self>, Receiver<Message>) {
88 let (message_tx, message_rx) = channel(512);
90 (
91 Arc::new(Self {
92 auth_timeout_task: Mutex::new(None),
93 cancel_token: CancellationToken::new(),
94 #[cfg(feature = "connection-extensions")]
95 extensions: WsIoServerConnectionExtensions::new(),
96 headers,
97 message_tx,
98 namespace,
99 on_close_handler: Mutex::new(None),
100 sid,
101 status: AtomicStatus::new(ConnectionStatus::Created),
102 }),
103 message_rx,
104 )
105 }
106
107 async fn abort_auth_timeout_task(&self) {
109 if let Some(auth_timeout_task) = self.auth_timeout_task.lock().await.take() {
110 auth_timeout_task.abort();
111 }
112 }
113
114 async fn activate(self: &Arc<Self>) -> Result<()> {
115 let status = self.status.get();
116 match status {
117 ConnectionStatus::Authenticating | ConnectionStatus::Created => {
118 self.status.try_transition(status, ConnectionStatus::Activating)?
119 }
120 _ => bail!("Cannot activate connection in invalid status: {:#?}", status),
121 }
122
123 if let Some(middleware) = &self.namespace.config.middleware {
124 middleware(self.clone()).await?;
125 }
126
127 if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
128 on_connect_handler(self.clone()).await?;
129 }
130
131 self.namespace.insert_connection(self.clone());
132 self.send_packet(&WsIoPacket {
133 data: None,
134 key: None,
135 r#type: WsIoPacketType::Ready,
136 })
137 .await?;
138
139 self.status
140 .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
141
142 if let Some(on_ready_handler) = &self.namespace.config.on_ready_handler {
143 on_ready_handler(self.clone()).await?;
144 }
145
146 Ok(())
147 }
148
149 async fn handle_auth_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
150 let status = self.status.get();
151 match status {
152 ConnectionStatus::AwaitingAuth => self.status.try_transition(status, ConnectionStatus::Authenticating)?,
153 _ => bail!("Received auth packet in invalid status: {:#?}", status),
154 }
155
156 if let Some(auth_handler) = &self.namespace.config.auth_handler {
157 (auth_handler)(self.clone(), packet_data).await?;
158 self.abort_auth_timeout_task().await;
159 self.activate().await
160 } else {
161 bail!("Auth packet received but no auth handler is configured");
162 }
163 }
164
165 async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
166 Ok(self
167 .message_tx
168 .send(self.namespace.encode_packet_to_message(packet)?)
169 .await?)
170 }
171
172 pub(crate) async fn cleanup(self: &Arc<Self>) {
174 self.status.store(ConnectionStatus::Closing);
175 self.abort_auth_timeout_task().await;
176 self.namespace.remove_connection(&self.sid);
177 self.cancel_token.cancel();
178 if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
179 let _ = on_close_handler(self.clone()).await;
180 }
181
182 self.status.store(ConnectionStatus::Closed);
183 }
184
185 pub(crate) async fn close(&self) {
186 match self.status.get() {
187 ConnectionStatus::Closed | ConnectionStatus::Closing => return,
188 _ => self.status.store(ConnectionStatus::Closing),
189 }
190
191 let _ = self.message_tx.send(Message::Close(None)).await;
192 }
193
194 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
195 let packet = self.namespace.config.packet_codec.decode(bytes)?;
196 match packet.r#type {
197 WsIoPacketType::Auth => self.handle_auth_packet(packet.data.as_deref()).await,
198 _ => Ok(()),
199 }
200 }
201
202 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
203 let require_auth = self.namespace.config.auth_handler.is_some();
204 let packet = WsIoPacket {
205 data: Some(self.namespace.config.packet_codec.encode_data(&require_auth)?),
206 key: None,
207 r#type: WsIoPacketType::Init,
208 };
209
210 if require_auth {
211 self.status.store(ConnectionStatus::AwaitingAuth);
212 let connection = self.clone();
213 *self.auth_timeout_task.lock().await = Some(spawn(async move {
214 sleep(connection.namespace.config.auth_timeout).await;
215 if matches!(connection.status.get(), ConnectionStatus::AwaitingAuth) {
216 connection.close().await;
217 }
218 }));
219
220 self.send_packet(&packet).await
221 } else {
222 self.send_packet(&packet).await?;
223 self.activate().await
224 }
225 }
226
227 #[inline]
230 pub fn cancel_token(&self) -> &CancellationToken {
231 &self.cancel_token
232 }
233
234 pub async fn disconnect(&self) {
235 let _ = self
236 .send_packet(&WsIoPacket {
237 data: None,
238 key: None,
239 r#type: WsIoPacketType::Disconnect,
240 })
241 .await;
242
243 self.close().await
244 }
245
246 pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
247 self.send_packet(&WsIoPacket {
248 data: data
249 .map(|data| self.namespace.config.packet_codec.encode_data(data))
250 .transpose()?,
251 key: Some(event.as_ref().to_string()),
252 r#type: WsIoPacketType::Event,
253 })
254 .await
255 }
256
257 #[cfg(feature = "connection-extensions")]
258 #[inline]
259 pub fn extensions(&self) -> &WsIoServerConnectionExtensions {
260 &self.extensions
261 }
262
263 #[inline]
264 pub fn headers(&self) -> &HeaderMap {
265 &self.headers
266 }
267
268 #[inline]
269 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
270 self.namespace.clone()
271 }
272
273 pub async fn on_close<H, Fut>(&self, handler: H)
274 where
275 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
276 Fut: Future<Output = Result<()>> + Send + 'static,
277 {
278 *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
279 }
280
281 #[inline]
282 pub fn server(&self) -> WsIoServer {
283 self.namespace.server()
284 }
285
286 #[inline]
287 pub fn sid(&self) -> &str {
288 &self.sid
289 }
290
291 #[inline]
292 pub fn spawn_task<F: Future<Output = Result<()>> + Send + 'static>(&self, future: F) {
293 let cancel_token = self.cancel_token().clone();
294 spawn(async move {
295 select! {
296 _ = cancel_token.cancelled() => {},
297 _ = future => {},
298 }
299 });
300 }
301}