rust_engineio/asynchronous/async_transports/websocket_secure.rs

1use std::fmt::Debug;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use crate::asynchronous::transport::AsyncTransport;
6use crate::error::Result;
7use async_trait::async_trait;
8use bytes::Bytes;
9use futures_util::Stream;
10use futures_util::StreamExt;
11use http::HeaderMap;
12use native_tls::TlsConnector;
13use tokio::sync::RwLock;
14use tokio_tungstenite::connect_async_tls_with_config;
15use tokio_tungstenite::Connector;
16use tungstenite::client::IntoClientRequest;
17use url::Url;
18
19use super::websocket_general::AsyncWebsocketGeneralTransport;
20
21/// An asynchronous websocket transport type.
22/// This type only allows for secure websocket
23/// connections ("wss://").
24#[derive(Clone)]
25pub struct WebsocketSecureTransport {
26    inner: AsyncWebsocketGeneralTransport,
27    base_url: Arc<RwLock<Url>>,
28}
29
30impl WebsocketSecureTransport {
31    /// Creates a new instance over a request that might hold additional headers, a possible
32    /// Tls connector and an URL.
33    pub(crate) async fn new(
34        base_url: Url,
35        tls_config: Option<TlsConnector>,
36        headers: Option<HeaderMap>,
37    ) -> Result<Self> {
38        let mut url = base_url;
39        url.query_pairs_mut().append_pair("transport", "websocket");
40        url.set_scheme("wss").unwrap();
41
42        let mut req = url.clone().into_client_request()?;
43        if let Some(map) = headers {
44            // SAFETY: this unwrap never panics as the underlying request is just initialized and in proper state
45            req.headers_mut().extend(map);
46        }
47
48        // `disable_nagle` Sets the value of the TCP_NODELAY option on this socket.
49        //
50        // If set to `true`, this option disables the Nagle algorithm.
51        // This means that segments are always sent as soon as possible, even if there is only a small amount of data.
52        // When `false`, data is buffered until there is a sufficient amount to send out, thereby avoiding the frequent sending of small packets.
53        //
54        // See the docs: https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html#method.set_nodelay
55        let (ws_stream, _) = connect_async_tls_with_config(
56            req,
57            None,
58            /*disable_nagle=*/ false,
59            tls_config.map(Connector::NativeTls),
60        )
61        .await?;
62
63        let (sen, rec) = ws_stream.split();
64        let inner = AsyncWebsocketGeneralTransport::new(sen, rec).await;
65
66        Ok(WebsocketSecureTransport {
67            inner,
68            base_url: Arc::new(RwLock::new(url)),
69        })
70    }
71
72    /// Sends probe packet to ensure connection is valid, then sends upgrade
73    /// request
74    pub(crate) async fn upgrade(&self) -> Result<()> {
75        self.inner.upgrade().await
76    }
77
78    pub(crate) async fn poll_next(&self) -> Result<Option<Bytes>> {
79        self.inner.poll_next().await
80    }
81}
82
83impl Stream for WebsocketSecureTransport {
84    type Item = Result<Bytes>;
85
86    fn poll_next(
87        mut self: Pin<&mut Self>,
88        cx: &mut std::task::Context<'_>,
89    ) -> std::task::Poll<Option<Self::Item>> {
90        self.inner.poll_next_unpin(cx)
91    }
92}
93
94#[async_trait]
95impl AsyncTransport for WebsocketSecureTransport {
96    async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
97        self.inner.emit(data, is_binary_att).await
98    }
99
100    async fn base_url(&self) -> Result<Url> {
101        Ok(self.base_url.read().await.clone())
102    }
103
104    async fn set_base_url(&self, base_url: Url) -> Result<()> {
105        let mut url = base_url;
106        if !url
107            .query_pairs()
108            .any(|(k, v)| k == "transport" && v == "websocket")
109        {
110            url.query_pairs_mut().append_pair("transport", "websocket");
111        }
112        url.set_scheme("wss").unwrap();
113        *self.base_url.write().await = url;
114        Ok(())
115    }
116}
117
118impl Debug for WebsocketSecureTransport {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        f.debug_struct("AsyncWebsocketSecureTransport")
121            .field(
122                "base_url",
123                &self
124                    .base_url
125                    .try_read()
126                    .map_or("Currently not available".to_owned(), |url| url.to_string()),
127            )
128            .finish()
129    }
130}
131
#[cfg(test)]
mod test {
    use super::*;
    use crate::ENGINE_IO_VERSION;
    use std::str::FromStr;

    /// Connects a transport to the secure test server's engine.io endpoint.
    async fn new() -> Result<WebsocketSecureTransport> {
        let server = crate::test::engine_io_server_secure()?;
        let endpoint = format!("{}engine.io/?EIO={}", server, ENGINE_IO_VERSION);
        let tls = crate::test::tls_connector()?;
        WebsocketSecureTransport::new(Url::from_str(&endpoint)?, Some(tls), None).await
    }

    #[tokio::test]
    async fn websocket_secure_transport_base_url() -> Result<()> {
        let transport = new().await?;

        let mut expected = crate::test::engine_io_server_secure()?;
        expected.set_path("/engine.io/");
        expected
            .query_pairs_mut()
            .append_pair("EIO", &ENGINE_IO_VERSION.to_string())
            .append_pair("transport", "websocket");
        expected.set_scheme("wss").unwrap();
        assert_eq!(transport.base_url().await?.to_string(), expected.to_string());

        // Without a transport pair: it is appended and the scheme is upgraded.
        transport
            .set_base_url(Url::parse("https://127.0.0.1")?)
            .await?;
        assert_eq!(
            transport.base_url().await?.to_string(),
            "wss://127.0.0.1/?transport=websocket"
        );
        assert_ne!(transport.base_url().await?.to_string(), expected.to_string());

        // With an existing transport pair: it is not duplicated.
        transport
            .set_base_url(Url::parse("http://127.0.0.1/?transport=websocket")?)
            .await?;
        assert_eq!(
            transport.base_url().await?.to_string(),
            "wss://127.0.0.1/?transport=websocket"
        );
        assert_ne!(transport.base_url().await?.to_string(), expected.to_string());
        Ok(())
    }

    #[tokio::test]
    async fn websocket_secure_debug() -> Result<()> {
        let transport = new().await?;
        let rendered = format!("{:?}", transport);
        let expected = format!(
            "AsyncWebsocketSecureTransport {{ base_url: {:?} }}",
            transport.base_url().await?.to_string()
        );
        assert_eq!(rendered, expected);
        Ok(())
    }
}