wsio_server/connection/mod.rs

1use std::sync::Arc;
2
3use anyhow::{
4    Result,
5    bail,
6};
7use http::HeaderMap;
8use num_enum::{
9    IntoPrimitive,
10    TryFromPrimitive,
11};
12use serde::Serialize;
13use tokio::{
14    select,
15    spawn,
16    sync::{
17        Mutex,
18        mpsc::{
19            Receiver,
20            Sender,
21            channel,
22        },
23    },
24    task::JoinHandle,
25    time::{
26        sleep,
27        timeout,
28    },
29};
30use tokio_tungstenite::tungstenite::Message;
31use tokio_util::sync::CancellationToken;
32
33#[cfg(feature = "connection-extensions")]
34mod extensions;
35
36#[cfg(feature = "connection-extensions")]
37use self::extensions::WsIoServerConnectionExtensions;
38use crate::{
39    WsIoServer,
40    core::{
41        atomic::status::AtomicStatus,
42        packet::{
43            WsIoPacket,
44            WsIoPacketType,
45        },
46        types::BoxAsyncUnaryResultHandler,
47        utils::task::abort_locked_task,
48    },
49    namespace::WsIoServerNamespace,
50};
51
52#[repr(u8)]
53#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
54enum ConnectionStatus {
55    Activating,
56    Authenticating,
57    AwaitingAuth,
58    Closed,
59    Closing,
60    Created,
61    Ready,
62}
63
64pub struct WsIoServerConnection {
65    auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
66    cancel_token: CancellationToken,
67    #[cfg(feature = "connection-extensions")]
68    extensions: WsIoServerConnectionExtensions,
69    headers: HeaderMap,
70    message_tx: Sender<Message>,
71    namespace: Arc<WsIoServerNamespace>,
72    on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
73    sid: String,
74    status: AtomicStatus<ConnectionStatus>,
75}
76
77impl WsIoServerConnection {
78    pub(crate) fn new(
79        headers: HeaderMap,
80        namespace: Arc<WsIoServerNamespace>,
81        sid: String,
82    ) -> (Arc<Self>, Receiver<Message>) {
83        let channel_capacity = (namespace.config.websocket_config.max_write_buffer_size
84            / namespace.config.websocket_config.write_buffer_size)
85            .clamp(64, 4096);
86
87        let (message_tx, message_rx) = channel(channel_capacity);
88        (
89            Arc::new(Self {
90                auth_timeout_task: Mutex::new(None),
91                cancel_token: CancellationToken::new(),
92                #[cfg(feature = "connection-extensions")]
93                extensions: WsIoServerConnectionExtensions::new(),
94                headers,
95                message_tx,
96                namespace,
97                on_close_handler: Mutex::new(None),
98                sid,
99                status: AtomicStatus::new(ConnectionStatus::Created),
100            }),
101            message_rx,
102        )
103    }
104
105    // Private methods
106    async fn activate(self: &Arc<Self>) -> Result<()> {
107        // Verify current state; only valid from Authenticating or Created → Activating
108        let status = self.status.get();
109        match status {
110            ConnectionStatus::Authenticating | ConnectionStatus::Created => {
111                self.status.try_transition(status, ConnectionStatus::Activating)?
112            }
113            _ => bail!("Cannot activate connection in invalid status: {:#?}", status),
114        }
115
116        // Invoke middleware with timeout protection if configured
117        if let Some(middleware) = &self.namespace.config.middleware {
118            timeout(
119                self.namespace.config.middleware_execution_timeout,
120                middleware(self.clone()),
121            )
122            .await??;
123        }
124
125        // Invoke on_connect handler with timeout protection if configured
126        if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
127            timeout(
128                self.namespace.config.on_connect_handler_timeout,
129                on_connect_handler(self.clone()),
130            )
131            .await??;
132        }
133
134        // Insert connection into namespace
135        self.namespace.insert_connection(self.clone());
136
137        // Transition state to Ready
138        self.status
139            .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
140
141        // Send ready packet
142        self.send_packet(&WsIoPacket {
143            data: None,
144            key: None,
145            r#type: WsIoPacketType::Ready,
146        })
147        .await?;
148
149        // Invoke on_ready handler if configured
150        if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
151            // Run handler asynchronously in a detached task
152            let connection = self.clone();
153            self.spawn_task(async move { on_ready_handler(connection).await });
154        }
155
156        Ok(())
157    }
158
159    async fn handle_auth_packet(self: &Arc<Self>, packet_data: &[u8]) -> Result<()> {
160        // Verify current state; only valid from AwaitingAuth → Authenticating
161        let status = self.status.get();
162        match status {
163            ConnectionStatus::AwaitingAuth => self.status.try_transition(status, ConnectionStatus::Authenticating)?,
164            _ => bail!("Received auth packet in invalid status: {:#?}", status),
165        }
166
167        // Abort auth-timeout task if still active
168        abort_locked_task(&self.auth_timeout_task).await;
169
170        // Invoke auth handler with timeout protection if configured, otherwise raise error
171        if let Some(auth_handler) = &self.namespace.config.auth_handler {
172            timeout(
173                self.namespace.config.auth_handler_timeout,
174                auth_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
175            )
176            .await??;
177
178            // Activate connection
179            self.activate().await
180        } else {
181            bail!("Auth packet received but no auth handler is configured");
182        }
183    }
184
185    async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
186        Ok(self
187            .message_tx
188            .send(self.namespace.encode_packet_to_message(packet)?)
189            .await?)
190    }
191
192    // Protected methods
193    pub(crate) async fn cleanup(self: &Arc<Self>) {
194        // Set connection state to Closing
195        self.status.store(ConnectionStatus::Closing);
196
197        // Abort auth-timeout task if still active
198        abort_locked_task(&self.auth_timeout_task).await;
199
200        // Remove connection from namespace
201        self.namespace.remove_connection(&self.sid);
202
203        // Cancel all ongoing operations via cancel token
204        self.cancel_token.cancel();
205
206        // Invoke on_close handler with timeout protection if configured
207        if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
208            let _ = timeout(
209                self.namespace.config.on_close_handler_timeout,
210                on_close_handler(self.clone()),
211            )
212            .await;
213        }
214
215        // Set connection state to Closed
216        self.status.store(ConnectionStatus::Closed);
217    }
218
219    #[inline]
220    pub(crate) fn close(&self) {
221        // Skip if connection is already Closing or Closed, otherwise set connection state to Closing
222        match self.status.get() {
223            ConnectionStatus::Closed | ConnectionStatus::Closing => return,
224            _ => self.status.store(ConnectionStatus::Closing),
225        }
226
227        // Send websocket close frame to initiate graceful shutdown
228        let _ = self.message_tx.try_send(Message::Close(None));
229    }
230
231    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
232        let packet = self.namespace.config.packet_codec.decode(bytes)?;
233        match packet.r#type {
234            WsIoPacketType::Auth => {
235                if let Some(packet_data) = packet.data.as_deref() {
236                    self.handle_auth_packet(packet_data).await
237                } else {
238                    bail!("Auth packet missing data");
239                }
240            }
241            _ => Ok(()),
242        }
243    }
244
245    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
246        // Verify current state; only valid Created
247        let status = self.status.get();
248        if !matches!(status, ConnectionStatus::Created) {
249            bail!("Cannot init connection in invalid status: {:#?}", status);
250        }
251
252        // Determine if authentication is required
253        let requires_auth = self.namespace.config.auth_handler.is_some();
254
255        // Build Init packet to inform client whether auth is required
256        let packet = WsIoPacket {
257            data: Some(self.namespace.config.packet_codec.encode_data(&requires_auth)?),
258            key: None,
259            r#type: WsIoPacketType::Init,
260        };
261
262        // If authentication is required
263        if requires_auth {
264            // Transition state to AwaitingAuth
265            self.status
266                .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingAuth)?;
267
268            // Spawn auth-packet-timeout watchdog to close connection if auth not received in time
269            let connection = self.clone();
270            *self.auth_timeout_task.lock().await = Some(spawn(async move {
271                sleep(connection.namespace.config.auth_packet_timeout).await;
272                if connection.status.is(ConnectionStatus::AwaitingAuth) {
273                    connection.close();
274                }
275            }));
276
277            // Send Init packet to client (expecting auth response)
278            self.send_packet(&packet).await
279        } else {
280            // Send Init packet to client (no auth required)
281            self.send_packet(&packet).await?;
282
283            // Immediately activate connection
284            self.activate().await
285        }
286    }
287
288    // Public methods
289
290    #[inline]
291    pub fn cancel_token(&self) -> &CancellationToken {
292        &self.cancel_token
293    }
294
295    pub async fn disconnect(&self) {
296        let _ = self
297            .send_packet(&WsIoPacket {
298                data: None,
299                key: None,
300                r#type: WsIoPacketType::Disconnect,
301            })
302            .await;
303
304        self.close()
305    }
306
307    pub async fn emit<D: Serialize>(&self, event: impl Into<String>, data: Option<&D>) -> Result<()> {
308        let status = self.status.get();
309        if status != ConnectionStatus::Ready {
310            bail!("Cannot emit event in invalid status: {:#?}", status);
311        }
312
313        self.send_packet(&WsIoPacket {
314            data: data
315                .map(|data| self.namespace.config.packet_codec.encode_data(data))
316                .transpose()?,
317            key: Some(event.into()),
318            r#type: WsIoPacketType::Event,
319        })
320        .await
321    }
322
323    #[cfg(feature = "connection-extensions")]
324    #[inline]
325    pub fn extensions(&self) -> &WsIoServerConnectionExtensions {
326        &self.extensions
327    }
328
329    #[inline]
330    pub fn headers(&self) -> &HeaderMap {
331        &self.headers
332    }
333
334    #[inline]
335    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
336        self.namespace.clone()
337    }
338
339    pub async fn on_close<H, Fut>(&self, handler: H)
340    where
341        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
342        Fut: Future<Output = Result<()>> + Send + 'static,
343    {
344        *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
345    }
346
347    #[inline]
348    pub fn server(&self) -> WsIoServer {
349        self.namespace.server()
350    }
351
352    #[inline]
353    pub fn sid(&self) -> &str {
354        &self.sid
355    }
356
357    #[inline]
358    pub fn spawn_task<F: Future<Output = Result<()>> + Send + 'static>(&self, future: F) {
359        let cancel_token = self.cancel_token.clone();
360        spawn(async move {
361            select! {
362                _ = cancel_token.cancelled() => {},
363                _ = future => {},
364            }
365        });
366    }
367}