// wsio_server/connection.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use bson::oid::ObjectId;
5use http::HeaderMap;
6use tokio::{
7    spawn,
8    sync::{
9        Mutex,
10        RwLock,
11        mpsc::{
12            UnboundedReceiver,
13            UnboundedSender,
14            unbounded_channel,
15        },
16    },
17    task::JoinHandle,
18    time::sleep,
19};
20use tokio_tungstenite::tungstenite::Message;
21
22use crate::{
23    core::packet::{
24        WsIoPacket,
25        WsIoPacketType,
26    },
27    namespace::WsIoServerNamespace,
28    types::handler::WsIoServerConnectionOnDisconnectHandler,
29};
30
/// Lifecycle state of a [`WsIoServerConnection`].
///
/// Variants are listed alphabetically; the typical progression is
/// `Created` -> (`AwaitingAuth` ->) `Activating` -> `Ready` -> `Closing` -> `Closed`.
enum WsIoServerConnectionStatus {
    /// Set at the start of `activate`, before the connection is registered with its namespace.
    Activating,

    /// Set by `init` when the namespace has an auth handler; the auth-timeout task closes the
    /// connection if it is still in this state once `auth_timeout` has elapsed.
    AwaitingAuth,

    /// Set as the final step of `cleanup`.
    Closed,

    /// Set as the first step of `cleanup`.
    Closing,

    /// Initial state assigned by `new`.
    Created,

    /// Set by `activate` after the connection has been registered with its namespace.
    Ready,
}
39
40pub struct WsIoServerConnection {
41    headers: HeaderMap,
42    namespace: Arc<WsIoServerNamespace>,
43    on_disconnect_handler: Mutex<Option<WsIoServerConnectionOnDisconnectHandler>>,
44    sid: String,
45    status: RwLock<WsIoServerConnectionStatus>,
46    tx: UnboundedSender<Message>,
47    wait_auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
48}
49
50impl WsIoServerConnection {
51    pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Self, UnboundedReceiver<Message>) {
52        let (tx, rx) = unbounded_channel();
53        (
54            Self {
55                headers,
56                namespace,
57                on_disconnect_handler: Mutex::new(None),
58                sid: ObjectId::new().to_string(),
59                status: RwLock::new(WsIoServerConnectionStatus::Created),
60                tx,
61                wait_auth_timeout_task: Mutex::new(None),
62            },
63            rx,
64        )
65    }
66
67    // Protected methods
68    pub(crate) async fn activate(self: &Arc<Self>) -> Result<()> {
69        *self.status.write().await = WsIoServerConnectionStatus::Activating;
70        // TODO: middlewares
71        self.namespace.insert_connection(self.clone());
72        *self.status.write().await = WsIoServerConnectionStatus::Ready;
73        let packet = WsIoPacket {
74            data: None,
75            key: None,
76            r#type: WsIoPacketType::Ready,
77        };
78
79        self.send_packet(&packet)?;
80        (self.namespace.config.on_connect_handler)(self.clone()).await
81    }
82
83    pub(crate) async fn cleanup(self: &Arc<Self>) {
84        *self.status.write().await = WsIoServerConnectionStatus::Closing;
85        if let Some(wait_auth_timeout_task) = self.wait_auth_timeout_task.lock().await.take() {
86            wait_auth_timeout_task.abort();
87        }
88
89        self.namespace.cleanup_connection(&self.sid);
90        if let Some(on_disconnect_handler) = self.on_disconnect_handler.lock().await.take() {
91            let _ = on_disconnect_handler(self.clone()).await;
92        }
93
94        *self.status.write().await = WsIoServerConnectionStatus::Closed;
95    }
96
97    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
98        self.send(Message::Text(format!("c{}", self.namespace.config.packet_codec).into()))?;
99        let require_auth = self.namespace.config.auth_handler.is_some();
100        let packet = WsIoPacket {
101            data: Some(self.namespace.config.packet_codec.encode_data(&require_auth)?),
102            key: Some(self.sid.clone()),
103            r#type: WsIoPacketType::Init,
104        };
105
106        if require_auth {
107            *self.status.write().await = WsIoServerConnectionStatus::AwaitingAuth;
108            let connection = self.clone();
109            self.wait_auth_timeout_task.lock().await.replace(spawn(async move {
110                sleep(connection.namespace.config.auth_timeout).await;
111                if matches!(
112                    *connection.status.read().await,
113                    WsIoServerConnectionStatus::AwaitingAuth
114                ) {
115                    connection.close();
116                }
117            }));
118
119            self.send_packet(&packet)?;
120        } else {
121            self.send_packet(&packet)?;
122            self.activate().await?;
123        }
124
125        Ok(())
126    }
127
128    pub(crate) async fn on_message(&self, _message: Message) {}
129
130    #[inline]
131    pub(crate) fn send(&self, message: Message) -> Result<()> {
132        Ok(self.tx.send(message)?)
133    }
134
135    #[inline]
136    pub(crate) fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
137        self.send(self.namespace.encode_packet_to_message(packet)?)
138    }
139
140    // Public methods
141
142    #[inline]
143    pub fn close(&self) {
144        let _ = self.send(Message::Close(None));
145    }
146
147    #[inline]
148    pub fn headers(&self) -> &HeaderMap {
149        &self.headers
150    }
151
152    #[inline]
153    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
154        self.namespace.clone()
155    }
156
157    pub async fn on_disconnect<H, Fut>(&self, handler: H)
158    where
159        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
160        Fut: Future<Output = Result<()>> + Send + 'static,
161    {
162        self.on_disconnect_handler
163            .lock()
164            .await
165            .replace(Box::new(move |connection| Box::pin(handler(connection))));
166    }
167
168    #[inline]
169    pub fn sid(&self) -> &str {
170        &self.sid
171    }
172}