Skip to main content

wsio_server/connection/mod.rs

1use std::sync::{
2    Arc,
3    LazyLock,
4    atomic::{
5        AtomicU64,
6        Ordering,
7    },
8};
9
10use anyhow::{
11    Result,
12    bail,
13};
14use arc_swap::ArcSwap;
15use http::{
16    HeaderMap,
17    Uri,
18};
19use kikiutils::{
20    atomic::enum_cell::AtomicEnumCell,
21    types::fx_collections::FxDashSet,
22};
23use num_enum::{
24    IntoPrimitive,
25    TryFromPrimitive,
26};
27use serde::{
28    Serialize,
29    de::DeserializeOwned,
30};
31use tokio::{
32    spawn,
33    sync::{
34        Mutex,
35        mpsc::{
36            Receiver,
37            Sender,
38            channel,
39        },
40    },
41    task::JoinHandle,
42    time::{
43        sleep,
44        timeout,
45    },
46};
47use tokio_tungstenite::tungstenite::Message;
48use tokio_util::sync::CancellationToken;
49
50#[cfg(feature = "connection-extensions")]
51mod extensions;
52
53#[cfg(feature = "connection-extensions")]
54use self::extensions::ConnectionExtensions;
55use crate::{
56    WsIoServer,
57    core::{
58        channel_capacity_from_websocket_config,
59        event::registry::WsIoEventRegistry,
60        packet::{
61            WsIoPacket,
62            WsIoPacketType,
63        },
64        traits::task::spawner::TaskSpawner,
65        types::BoxAsyncUnaryResultHandler,
66        utils::task::abort_locked_task,
67    },
68    namespace::{
69        WsIoServerNamespace,
70        operators::broadcast::WsIoServerNamespaceBroadcastOperator,
71    },
72};
73
74// Enums
75#[repr(u8)]
76#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
77enum ConnectionState {
78    Activating,
79    AwaitingInit,
80    Closed,
81    Closing,
82    Created,
83    Initiating,
84    Ready,
85}
86
// Structs
/// Server-side state for a single websocket client attached to a namespace.
pub struct WsIoServerConnection {
    /// Token returned by `TaskSpawner::cancel_token`; cancelled during `cleanup`.
    cancel_token: ArcSwap<CancellationToken>,
    /// Per-connection event handlers registered through `on`.
    event_registry: WsIoEventRegistry<WsIoServerConnection, WsIoServerConnection>,
    /// Connection-scoped extension storage (only with the `connection-extensions` feature).
    #[cfg(feature = "connection-extensions")]
    extensions: ConnectionExtensions,
    /// HTTP headers supplied at construction (presumably from the upgrade request — confirm).
    headers: HeaderMap,
    /// Id taken from the process-wide `NEXT_CONNECTION_ID` counter.
    id: u64,
    /// Watchdog spawned by `init` that closes the connection if no init packet arrives in time;
    /// aborted when the init packet is handled or during cleanup.
    init_timeout_task: Mutex<Option<JoinHandle<()>>>,
    /// Rooms joined via `join`; each is also registered in the namespace's room index and both
    /// are cleared on cleanup.
    joined_rooms: FxDashSet<String>,
    /// Sending half of the outbound message queue; the receiver is returned from `new`.
    message_tx: Sender<Arc<Message>>,
    /// Namespace this connection belongs to.
    namespace: Arc<WsIoServerNamespace>,
    /// Handler set by `on_close`; taken and invoked at most once by `cleanup`.
    on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
    /// Request URI supplied at construction.
    request_uri: Uri,
    /// Current lifecycle state.
    state: AtomicEnumCell<ConnectionState>,
}
103
104impl TaskSpawner for WsIoServerConnection {
105    #[inline]
106    fn cancel_token(&self) -> Arc<CancellationToken> {
107        self.cancel_token.load_full()
108    }
109}
110
111impl WsIoServerConnection {
112    #[inline]
113    pub(crate) fn new(
114        headers: HeaderMap,
115        namespace: Arc<WsIoServerNamespace>,
116        request_uri: Uri,
117    ) -> (Arc<Self>, Receiver<Arc<Message>>) {
118        let channel_capacity = channel_capacity_from_websocket_config(&namespace.config.websocket_config);
119        let (message_tx, message_rx) = channel(channel_capacity);
120        (
121            Arc::new(Self {
122                cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
123                event_registry: WsIoEventRegistry::new(),
124                #[cfg(feature = "connection-extensions")]
125                extensions: ConnectionExtensions::new(),
126                headers,
127                id: NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed),
128                init_timeout_task: Mutex::new(None),
129                joined_rooms: FxDashSet::default(),
130                message_tx,
131                namespace,
132                on_close_handler: Mutex::new(None),
133                request_uri,
134                state: AtomicEnumCell::new(ConnectionState::Created),
135            }),
136            message_rx,
137        )
138    }
139
140    // Private methods
141    #[inline]
142    fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
143        self.event_registry.dispatch_event_packet(
144            self.clone(),
145            event,
146            &self.namespace.config.packet_codec,
147            packet_data,
148            self,
149        );
150
151        Ok(())
152    }
153
154    async fn handle_init_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
155        // Verify current state; only valid from AwaitingInit → Initiating
156        let state = self.state.get();
157        match state {
158            ConnectionState::AwaitingInit => self.state.try_transition(state, ConnectionState::Initiating)?,
159            _ => bail!("Received init packet in invalid state: {state:?}"),
160        }
161
162        // Abort init-timeout task
163        abort_locked_task(&self.init_timeout_task).await;
164
165        // Invoke init_response_handler with timeout protection if configured
166        if let Some(init_response_handler) = &self.namespace.config.init_response_handler {
167            timeout(
168                self.namespace.config.init_response_handler_timeout,
169                init_response_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
170            )
171            .await??
172        }
173
174        // Activate connection
175        self.state
176            .try_transition(ConnectionState::Initiating, ConnectionState::Activating)?;
177
178        // Invoke middleware with timeout protection if configured
179        if let Some(middleware) = &self.namespace.config.middleware {
180            timeout(
181                self.namespace.config.middleware_execution_timeout,
182                middleware(self.clone()),
183            )
184            .await??;
185
186            // Ensure connection is still in Activating state
187            self.state.ensure(ConnectionState::Activating, |state| {
188                format!("Cannot activate connection in invalid state: {state:?}")
189            })?;
190        }
191
192        // Invoke on_connect_handler with timeout protection if configured
193        if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
194            timeout(
195                self.namespace.config.on_connect_handler_timeout,
196                on_connect_handler(self.clone()),
197            )
198            .await??;
199        }
200
201        // Transition state to Ready
202        self.state
203            .try_transition(ConnectionState::Activating, ConnectionState::Ready)?;
204
205        // Insert connection into namespace
206        self.namespace.insert_connection(self.clone());
207
208        // Send ready packet
209        self.send_packet(&WsIoPacket::new_ready()).await?;
210
211        // Invoke on_ready_handler if configured
212        if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
213            // Run handler asynchronously in a detached task
214            self.spawn_task(on_ready_handler(self.clone()));
215        }
216
217        Ok(())
218    }
219
220    async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
221        self.send_message(self.namespace.encode_packet_to_message(packet)?)
222            .await
223    }
224
225    // Protected methods
226    pub(crate) async fn cleanup(self: &Arc<Self>) {
227        // Set connection state to Closing
228        self.state.store(ConnectionState::Closing);
229
230        // Remove connection from namespace
231        self.namespace.remove_connection(self.id);
232
233        // Leave all joined rooms
234        let joined_rooms = self.joined_rooms.iter().map(|entry| entry.clone()).collect::<Vec<_>>();
235        for room_name in &joined_rooms {
236            self.namespace.remove_connection_id_from_room(room_name, self.id);
237        }
238
239        self.joined_rooms.clear();
240
241        // Abort init-timeout task
242        abort_locked_task(&self.init_timeout_task).await;
243
244        // Cancel all ongoing operations via cancel token
245        self.cancel_token.load().cancel();
246
247        // Invoke on_close_handler with timeout protection if configured
248        if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
249            let _ = timeout(
250                self.namespace.config.on_close_handler_timeout,
251                on_close_handler(self.clone()),
252            )
253            .await;
254        }
255
256        // Set connection state to Closed
257        self.state.store(ConnectionState::Closed);
258    }
259
260    #[inline]
261    pub(crate) fn close(&self) {
262        // Skip if connection is already Closing or Closed, otherwise set connection state to Closing
263        match self.state.get() {
264            ConnectionState::Closed | ConnectionState::Closing => return,
265            _ => self.state.store(ConnectionState::Closing),
266        }
267
268        // Send websocket close frame to initiate graceful shutdown
269        let _ = self.message_tx.try_send(Arc::new(Message::Close(None)));
270    }
271
272    pub(crate) async fn emit_event_message(&self, message: Arc<Message>) -> Result<()> {
273        self.state.ensure(ConnectionState::Ready, |state| {
274            format!("Cannot emit in invalid state: {state:?}")
275        })?;
276
277        self.send_message(message).await
278    }
279
280    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, encoded_packet: &[u8]) -> Result<()> {
281        // TODO: lazy load
282        let packet = self.namespace.config.packet_codec.decode(encoded_packet)?;
283        match packet.r#type {
284            WsIoPacketType::Event => {
285                if self.is_ready() {
286                    if let Some(event) = packet.key.as_deref() {
287                        return self.handle_event_packet(event, packet.data);
288                    } else {
289                        bail!("Event packet missing key");
290                    }
291                }
292
293                Ok(())
294            }
295            WsIoPacketType::Init => self.handle_init_packet(packet.data.as_deref()).await,
296            _ => Ok(()),
297        }
298    }
299
300    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
301        // Verify current state; only valid Created
302        self.state.ensure(ConnectionState::Created, |state| {
303            format!("Cannot init connection in invalid state: {state:?}")
304        })?;
305
306        // Generate init request data if init request handler is configured
307        let init_request_data = if let Some(init_request_handler) = &self.namespace.config.init_request_handler {
308            timeout(
309                self.namespace.config.init_request_handler_timeout,
310                init_request_handler(self.clone(), &self.namespace.config.packet_codec),
311            )
312            .await??
313        } else {
314            None
315        };
316
317        // Transition state to AwaitingInit
318        self.state
319            .try_transition(ConnectionState::Created, ConnectionState::AwaitingInit)?;
320
321        // Spawn init-response-timeout watchdog to close connection if init not received in time
322        let connection = self.clone();
323        *self.init_timeout_task.lock().await = Some(spawn(async move {
324            sleep(connection.namespace.config.init_response_timeout).await;
325            if connection.state.is(ConnectionState::AwaitingInit) {
326                connection.close();
327            }
328        }));
329
330        // Send init packet
331        self.send_packet(&WsIoPacket::new_init(init_request_data)).await
332    }
333
334    pub(crate) async fn send_message(&self, message: Arc<Message>) -> Result<()> {
335        Ok(self.message_tx.send(message).await?)
336    }
337
338    // Public methods
339    pub async fn disconnect(&self) {
340        let _ = self.send_packet(&WsIoPacket::new_disconnect()).await;
341        self.close()
342    }
343
344    pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
345        self.emit_event_message(
346            self.namespace.encode_packet_to_message(&WsIoPacket::new_event(
347                event.as_ref(),
348                data.map(|data| self.namespace.config.packet_codec.encode_data(data))
349                    .transpose()?,
350            ))?,
351        )
352        .await
353    }
354
355    #[inline]
356    pub fn except(
357        self: &Arc<Self>,
358        room_names: impl IntoIterator<Item = impl Into<String>>,
359    ) -> WsIoServerNamespaceBroadcastOperator {
360        self.namespace.except(room_names).except_connection_ids([self.id])
361    }
362
363    #[cfg(feature = "connection-extensions")]
364    #[inline]
365    pub fn extensions(&self) -> &ConnectionExtensions {
366        &self.extensions
367    }
368
369    #[inline]
370    pub fn headers(&self) -> &HeaderMap {
371        &self.headers
372    }
373
374    #[inline]
375    pub fn id(&self) -> u64 {
376        self.id
377    }
378
379    #[inline]
380    pub fn is_ready(&self) -> bool {
381        self.state.is(ConnectionState::Ready)
382    }
383
384    #[inline]
385    pub fn join(self: &Arc<Self>, room_names: impl IntoIterator<Item = impl Into<String>>) {
386        for room_name in room_names {
387            let room_name = room_name.into();
388            self.namespace.add_connection_id_to_room(&room_name, self.id);
389            self.joined_rooms.insert(room_name);
390        }
391    }
392
393    #[inline]
394    pub fn leave(self: &Arc<Self>, room_names: impl IntoIterator<Item = impl Into<String>>) {
395        for room_name in room_names {
396            let room_name = &room_name.into();
397            self.namespace.remove_connection_id_from_room(room_name, self.id);
398
399            self.joined_rooms.remove(room_name);
400        }
401    }
402
403    #[inline]
404    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
405        self.namespace.clone()
406    }
407
408    #[inline]
409    pub fn off(&self, event: impl AsRef<str>) {
410        self.event_registry.off(event.as_ref());
411    }
412
413    #[inline]
414    pub fn off_by_handler_id(&self, event: impl AsRef<str>, handler_id: u32) {
415        self.event_registry.off_by_handler_id(event.as_ref(), handler_id);
416    }
417
418    #[inline]
419    pub fn on<H, Fut, D>(&self, event: impl AsRef<str>, handler: H) -> u32
420    where
421        H: Fn(Arc<WsIoServerConnection>, Arc<D>) -> Fut + Send + Sync + 'static,
422        Fut: Future<Output = Result<()>> + Send + 'static,
423        D: DeserializeOwned + Send + Sync + 'static,
424    {
425        self.event_registry.on(event.as_ref(), handler)
426    }
427
428    pub async fn on_close<H, Fut>(&self, handler: H)
429    where
430        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
431        Fut: Future<Output = Result<()>> + Send + 'static,
432    {
433        *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
434    }
435
436    #[inline]
437    pub fn request_uri(&self) -> &Uri {
438        &self.request_uri
439    }
440
441    #[inline]
442    pub fn server(&self) -> WsIoServer {
443        self.namespace.server()
444    }
445
446    #[inline]
447    pub fn to(
448        self: &Arc<Self>,
449        room_names: impl IntoIterator<Item = impl Into<String>>,
450    ) -> WsIoServerNamespaceBroadcastOperator {
451        self.namespace.to(room_names).except_connection_ids([self.id])
452    }
453}
454
// Constants/Statics
/// Process-wide source of connection ids. It starts at 0 and `WsIoServerConnection::new`
/// takes one value per constructed connection.
static NEXT_CONNECTION_ID: LazyLock<AtomicU64> = LazyLock::new(AtomicU64::default);
457
#[cfg(test)]
mod tests {
    use http::{
        HeaderMap,
        Uri,
    };

    use super::*;

    /// Builds a fresh connection attached to a throwaway `/socket` namespace.
    ///
    /// The outbound receiver is dropped on return, so queued sends fail. These tests only inspect
    /// connection state and room/namespace bookkeeping.
    fn create_test_connection() -> Arc<WsIoServerConnection> {
        let server = Arc::new(WsIoServer::builder().build());
        let namespace = server.new_namespace_builder("/socket").register().unwrap();
        let (connection, _rx) =
            WsIoServerConnection::new(HeaderMap::new(), namespace, Uri::from_static("http://localhost"));

        connection
    }

    #[tokio::test]
    async fn test_handle_incoming_packet_decode_error() {
        let connection = create_test_connection();
        let garbage_data = b"obviously not valid json or messagepack";
        // Should return a Result::Err, not panic
        let result = connection.handle_incoming_packet(garbage_data).await;
        assert!(result.is_err(), "Decoding garbage payload should trigger an error");
    }

    #[tokio::test]
    async fn test_handle_init_packet_in_invalid_state() {
        let connection = create_test_connection();
        assert_eq!(connection.state.get(), ConnectionState::Created);

        // An init packet while merely `Created` (not yet `AwaitingInit`) must be rejected.
        // Init packet JSON encoded (type: 2 = Init) -> serialized as tuple array
        let encoded = b"[2,null,null]";

        // Simulates a client pushing Init before the server has started the handshake
        let result = connection.handle_incoming_packet(encoded).await;
        assert!(
            result.is_err(),
            "Should error because state is Created, not AwaitingInit"
        );

        assert!(result.unwrap_err().to_string().contains("invalid state"));
    }

    #[tokio::test]
    async fn test_handle_event_packet_missing_key() {
        let connection = create_test_connection();

        // Force the connection into the Ready state so it accepts Event packets
        connection.state.store(ConnectionState::Ready);

        // Event packet without a key (type: 1 = Event) -> serialized as tuple array
        let encoded = b"[1,null,null]";

        let result = connection.handle_incoming_packet(encoded).await;
        assert!(result.is_err(), "Should bail on missing event key");
        assert_eq!(result.unwrap_err().to_string(), "Event packet missing key");
    }

    #[tokio::test]
    async fn test_connection_close_state_transitions() {
        let connection = create_test_connection();
        assert_eq!(connection.state.get(), ConnectionState::Created);

        connection.close();
        assert_eq!(connection.state.get(), ConnectionState::Closing);

        // Calling close again when Closing shouldn't alter anything
        connection.close();
        assert_eq!(connection.state.get(), ConnectionState::Closing);
    }

    #[tokio::test]
    async fn test_close_after_cleanup_keeps_closed_state() {
        let connection = create_test_connection();
        connection.namespace().insert_connection(connection.clone());

        connection.cleanup().await;
        assert_eq!(connection.state.get(), ConnectionState::Closed);

        // close() must never regress a cleaned-up connection back to Closing
        connection.close();
        assert_eq!(connection.state.get(), ConnectionState::Closed);
    }

    #[tokio::test]
    async fn test_connection_cleanup() {
        let connection = create_test_connection();
        let namespace = connection.namespace();

        // Insert connection manually for test
        namespace.insert_connection(connection.clone());
        assert_eq!(namespace.connection_count(), 1);

        connection.join(["room_a", "room_b"]);
        assert!(connection.joined_rooms.contains("room_a"));

        connection.cleanup().await;

        assert_eq!(connection.state.get(), ConnectionState::Closed);
        assert!(connection.joined_rooms.is_empty());
        assert_eq!(namespace.connection_count(), 0);
    }

    #[tokio::test]
    async fn test_join_and_leave_rooms() {
        let connection = create_test_connection();

        connection.join(["room_a", "room_b"]);
        assert!(connection.joined_rooms.contains("room_a"));
        assert!(connection.joined_rooms.contains("room_b"));

        connection.leave(["room_a"]);
        assert!(!connection.joined_rooms.contains("room_a"));
        assert!(connection.joined_rooms.contains("room_b"));
    }

    #[tokio::test]
    async fn test_connection_ids_are_distinct() {
        let first = create_test_connection();
        let second = create_test_connection();
        assert_ne!(first.id(), second.id());
    }
}