wsio_server/connection/mod.rs

1use std::sync::{
2    Arc,
3    LazyLock,
4    atomic::{
5        AtomicU64,
6        Ordering,
7    },
8};
9
10use anyhow::{
11    Result,
12    bail,
13};
14use arc_swap::ArcSwap;
15use http::HeaderMap;
16use num_enum::{
17    IntoPrimitive,
18    TryFromPrimitive,
19};
20use serde::{
21    Serialize,
22    de::DeserializeOwned,
23};
24use tokio::{
25    spawn,
26    sync::{
27        Mutex,
28        mpsc::{
29            Receiver,
30            Sender,
31            channel,
32        },
33    },
34    task::JoinHandle,
35    time::{
36        sleep,
37        timeout,
38    },
39};
40use tokio_tungstenite::tungstenite::Message;
41use tokio_util::sync::CancellationToken;
42
43#[cfg(feature = "connection-extensions")]
44mod extensions;
45
46#[cfg(feature = "connection-extensions")]
47use self::extensions::ConnectionExtensions;
48use crate::{
49    WsIoServer,
50    core::{
51        atomic::status::AtomicStatus,
52        channel_capacity_from_websocket_config,
53        event::registry::WsIoEventRegistry,
54        packet::{
55            WsIoPacket,
56            WsIoPacketType,
57        },
58        traits::task::spawner::TaskSpawner,
59        types::BoxAsyncUnaryResultHandler,
60        utils::task::abort_locked_task,
61    },
62    namespace::WsIoServerNamespace,
63};
64
65// Enums
66#[repr(u8)]
67#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
68enum ConnectionStatus {
69    Activating,
70    Authenticating,
71    AwaitingAuth,
72    Closed,
73    Closing,
74    Created,
75    Ready,
76}
77
// Structs

/// A single client connection belonging to a [`WsIoServerNamespace`].
///
/// Created by `new` in the `Created` status, driven through `init` (and optionally auth) to
/// `Ready`, and torn down by `cleanup`. Outbound websocket messages are queued on `message_tx`.
///
/// NOTE(review): field declaration order is also drop order; keep it stable unless that is intended.
pub struct WsIoServerConnection {
    // Watchdog spawned by `init` when auth is required; closes the connection if it is still
    // `AwaitingAuth` after `auth_packet_timeout`. Aborted by `handle_auth_packet` and `cleanup`.
    auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
    // Token handed out through `TaskSpawner::cancel_token`; cancelled in `cleanup`.
    cancel_token: ArcSwap<CancellationToken>,
    // Per-connection event handlers managed by `on` / `off` / `off_by_handler_id` and used to
    // dispatch incoming Event packets.
    event_registry: WsIoEventRegistry<WsIoServerConnection, WsIoServerConnection>,
    // Feature-gated extension storage, exposed through `extensions()`.
    #[cfg(feature = "connection-extensions")]
    extensions: ConnectionExtensions,
    // Headers supplied to `new` (presumably from the websocket handshake request — confirm at the call site).
    headers: HeaderMap,
    // Identifier taken from the global `NEXT_CONNECTION_ID` counter.
    id: u64,
    // Sending half of the outbound message channel; the receiver is returned by `new`.
    message_tx: Sender<Message>,
    // Namespace that owns this connection and provides its configuration and packet codec.
    namespace: Arc<WsIoServerNamespace>,
    // Optional close callback; `cleanup` takes it, so it runs at most once.
    on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
    // Current lifecycle status.
    status: AtomicStatus<ConnectionStatus>,
}
92
93impl TaskSpawner for WsIoServerConnection {
94    #[inline]
95    fn cancel_token(&self) -> Arc<CancellationToken> {
96        self.cancel_token.load_full()
97    }
98}
99
100impl WsIoServerConnection {
101    #[inline]
102    pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Arc<Self>, Receiver<Message>) {
103        let channel_capacity = channel_capacity_from_websocket_config(&namespace.config.websocket_config);
104        let (message_tx, message_rx) = channel(channel_capacity);
105        (
106            Arc::new(Self {
107                auth_timeout_task: Mutex::new(None),
108                cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
109                event_registry: WsIoEventRegistry::new(),
110                #[cfg(feature = "connection-extensions")]
111                extensions: ConnectionExtensions::new(),
112                headers,
113                id: NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed),
114                message_tx,
115                namespace,
116                on_close_handler: Mutex::new(None),
117                status: AtomicStatus::new(ConnectionStatus::Created),
118            }),
119            message_rx,
120        )
121    }
122
123    // Private methods
124    async fn activate(self: &Arc<Self>) -> Result<()> {
125        // Verify current state; only valid from Authenticating or Created → Activating
126        let status = self.status.get();
127        match status {
128            ConnectionStatus::Authenticating | ConnectionStatus::Created => {
129                self.status.try_transition(status, ConnectionStatus::Activating)?
130            }
131            _ => bail!("Cannot activate connection in invalid status: {:#?}", status),
132        }
133
134        // Invoke middleware with timeout protection if configured
135        if let Some(middleware) = &self.namespace.config.middleware {
136            timeout(
137                self.namespace.config.middleware_execution_timeout,
138                middleware(self.clone()),
139            )
140            .await??;
141        }
142
143        // Invoke on_connect handler with timeout protection if configured
144        if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
145            timeout(
146                self.namespace.config.on_connect_handler_timeout,
147                on_connect_handler(self.clone()),
148            )
149            .await??;
150        }
151
152        // Insert connection into namespace
153        self.namespace.insert_connection(self.clone());
154
155        // Transition state to Ready
156        self.status
157            .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
158
159        // Send ready packet
160        self.send_packet(&WsIoPacket::new_ready()).await?;
161
162        // Invoke on_ready handler if configured
163        if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
164            // Run handler asynchronously in a detached task
165            self.spawn_task(on_ready_handler(self.clone()));
166        }
167
168        Ok(())
169    }
170
171    #[inline]
172    fn ensure_status_ready(&self) -> Result<()> {
173        self.status.ensure(ConnectionStatus::Ready, |status| {
174            format!("Cannot emit event in invalid status: {:#?}", status)
175        })?;
176
177        Ok(())
178    }
179
180    async fn handle_auth_packet(self: &Arc<Self>, packet_data: &[u8]) -> Result<()> {
181        // Verify current state; only valid from AwaitingAuth → Authenticating
182        let status = self.status.get();
183        match status {
184            ConnectionStatus::AwaitingAuth => self.status.try_transition(status, ConnectionStatus::Authenticating)?,
185            _ => bail!("Received auth packet in invalid status: {:#?}", status),
186        }
187
188        // Abort auth-timeout task if still active
189        abort_locked_task(&self.auth_timeout_task).await;
190
191        // Invoke auth handler with timeout protection if configured, otherwise raise error
192        if let Some(auth_handler) = &self.namespace.config.auth_handler {
193            timeout(
194                self.namespace.config.auth_handler_timeout,
195                auth_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
196            )
197            .await??;
198
199            // Activate connection
200            self.activate().await
201        } else {
202            bail!("Auth packet received but no auth handler is configured");
203        }
204    }
205
206    #[inline]
207    fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
208        self.event_registry.dispatch_event_packet(
209            self.clone(),
210            event,
211            &self.namespace.config.packet_codec,
212            packet_data,
213            self,
214        );
215
216        Ok(())
217    }
218
219    async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
220        Ok(self
221            .message_tx
222            .send(self.namespace.encode_packet_to_message(packet)?)
223            .await?)
224    }
225
226    // Protected methods
227    pub(crate) async fn cleanup(self: &Arc<Self>) {
228        // Set connection state to Closing
229        self.status.store(ConnectionStatus::Closing);
230
231        // Abort auth-timeout task if still active
232        abort_locked_task(&self.auth_timeout_task).await;
233
234        // Remove connection from namespace
235        self.namespace.remove_connection(self.id);
236
237        // Cancel all ongoing operations via cancel token
238        self.cancel_token.load().cancel();
239
240        // Invoke on_close handler with timeout protection if configured
241        if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
242            let _ = timeout(
243                self.namespace.config.on_close_handler_timeout,
244                on_close_handler(self.clone()),
245            )
246            .await;
247        }
248
249        // Set connection state to Closed
250        self.status.store(ConnectionStatus::Closed);
251    }
252
253    #[inline]
254    pub(crate) fn close(&self) {
255        // Skip if connection is already Closing or Closed, otherwise set connection state to Closing
256        match self.status.get() {
257            ConnectionStatus::Closed | ConnectionStatus::Closing => return,
258            _ => self.status.store(ConnectionStatus::Closing),
259        }
260
261        // Send websocket close frame to initiate graceful shutdown
262        let _ = self.message_tx.try_send(Message::Close(None));
263    }
264
265    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
266        // TODO: lazy load
267        let packet = self.namespace.config.packet_codec.decode(bytes)?;
268        match packet.r#type {
269            WsIoPacketType::Auth => {
270                if let Some(packet_data) = packet.data.as_deref() {
271                    self.handle_auth_packet(packet_data).await
272                } else {
273                    bail!("Auth packet missing data");
274                }
275            }
276            WsIoPacketType::Event => {
277                if let Some(event) = packet.key.as_deref() {
278                    self.handle_event_packet(event, packet.data)
279                } else {
280                    bail!("Event packet missing key");
281                }
282            }
283            _ => Ok(()),
284        }
285    }
286
287    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
288        // Verify current state; only valid Created
289        self.status.ensure(ConnectionStatus::Created, |status| {
290            format!("Cannot init connection in invalid status: {:#?}", status)
291        })?;
292
293        // Determine if authentication is required
294        let requires_auth = self.namespace.config.auth_handler.is_some();
295
296        // Build Init packet to inform client whether auth is required
297        let packet = &WsIoPacket::new_init(self.namespace.config.packet_codec.encode_data(&requires_auth)?);
298
299        // If authentication is required
300        if requires_auth {
301            // Transition state to AwaitingAuth
302            self.status
303                .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingAuth)?;
304
305            // Spawn auth-packet-timeout watchdog to close connection if auth not received in time
306            let connection = self.clone();
307            *self.auth_timeout_task.lock().await = Some(spawn(async move {
308                sleep(connection.namespace.config.auth_packet_timeout).await;
309                if connection.status.is(ConnectionStatus::AwaitingAuth) {
310                    connection.close();
311                }
312            }));
313
314            // Send Init packet to client (expecting auth response)
315            self.send_packet(packet).await
316        } else {
317            // Send Init packet to client (no auth required)
318            self.send_packet(packet).await?;
319
320            // Immediately activate connection
321            self.activate().await
322        }
323    }
324
325    // Public methods
326    pub async fn disconnect(&self) {
327        let _ = self.send_packet(&WsIoPacket::new_disconnect()).await;
328        self.close()
329    }
330
331    pub async fn emit<D: Serialize>(&self, event: impl Into<String>, data: Option<&D>) -> Result<()> {
332        self.ensure_status_ready()?;
333        self.send_packet(&WsIoPacket::new_event(
334            event,
335            data.map(|data| self.namespace.config.packet_codec.encode_data(data))
336                .transpose()?,
337        ))
338        .await
339    }
340
341    pub async fn emit_message(&self, message: Message) -> Result<()> {
342        self.ensure_status_ready()?;
343        Ok(self.message_tx.send(message).await?)
344    }
345
346    #[cfg(feature = "connection-extensions")]
347    #[inline]
348    pub fn extensions(&self) -> &ConnectionExtensions {
349        &self.extensions
350    }
351
352    #[inline]
353    pub fn headers(&self) -> &HeaderMap {
354        &self.headers
355    }
356
357    #[inline]
358    pub fn id(&self) -> u64 {
359        self.id
360    }
361
362    #[inline]
363    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
364        self.namespace.clone()
365    }
366
367    #[inline]
368    pub fn off(&self, event: impl AsRef<str>) {
369        self.event_registry.off(event);
370    }
371
372    #[inline]
373    pub fn off_by_handler_id(&self, event: impl AsRef<str>, handler_id: u32) {
374        self.event_registry.off_by_handler_id(event, handler_id);
375    }
376
377    #[inline]
378    pub fn on<H, Fut, D>(&self, event: impl Into<String>, handler: H) -> u32
379    where
380        H: Fn(Arc<WsIoServerConnection>, Arc<D>) -> Fut + Send + Sync + 'static,
381        Fut: Future<Output = Result<()>> + Send + 'static,
382        D: DeserializeOwned + Send + Sync + 'static,
383    {
384        self.event_registry.on(event, handler)
385    }
386
387    pub async fn on_close<H, Fut>(&self, handler: H)
388    where
389        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
390        Fut: Future<Output = Result<()>> + Send + 'static,
391    {
392        *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
393    }
394
395    #[inline]
396    pub fn server(&self) -> WsIoServer {
397        self.namespace.server()
398    }
399}
400
// Constants/Statics

/// Global counter that hands out connection ids; the first id issued is 0.
static NEXT_CONNECTION_ID: LazyLock<AtomicU64> = LazyLock::new(AtomicU64::default);