rust_engineio/asynchronous/async_transports/
websocket_secure.rs1use std::fmt::Debug;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use crate::asynchronous::transport::AsyncTransport;
6use crate::error::Result;
7use async_trait::async_trait;
8use bytes::Bytes;
9use futures_util::Stream;
10use futures_util::StreamExt;
11use http::HeaderMap;
12use native_tls::TlsConnector;
13use tokio::sync::RwLock;
14use tokio_tungstenite::connect_async_tls_with_config;
15use tokio_tungstenite::Connector;
16use tungstenite::client::IntoClientRequest;
17use url::Url;
18
19use super::websocket_general::AsyncWebsocketGeneralTransport;
20
21#[derive(Clone)]
25pub struct WebsocketSecureTransport {
26 inner: AsyncWebsocketGeneralTransport,
27 base_url: Arc<RwLock<Url>>,
28}
29
30impl WebsocketSecureTransport {
31 pub(crate) async fn new(
34 base_url: Url,
35 tls_config: Option<TlsConnector>,
36 headers: Option<HeaderMap>,
37 ) -> Result<Self> {
38 let mut url = base_url;
39 url.query_pairs_mut().append_pair("transport", "websocket");
40 url.set_scheme("wss").unwrap();
41
42 let mut req = url.clone().into_client_request()?;
43 if let Some(map) = headers {
44 req.headers_mut().extend(map);
46 }
47
48 let (ws_stream, _) = connect_async_tls_with_config(
56 req,
57 None,
58 false,
59 tls_config.map(Connector::NativeTls),
60 )
61 .await?;
62
63 let (sen, rec) = ws_stream.split();
64 let inner = AsyncWebsocketGeneralTransport::new(sen, rec).await;
65
66 Ok(WebsocketSecureTransport {
67 inner,
68 base_url: Arc::new(RwLock::new(url)),
69 })
70 }
71
72 pub(crate) async fn upgrade(&self) -> Result<()> {
75 self.inner.upgrade().await
76 }
77
78 pub(crate) async fn poll_next(&self) -> Result<Option<Bytes>> {
79 self.inner.poll_next().await
80 }
81}
82
83impl Stream for WebsocketSecureTransport {
84 type Item = Result<Bytes>;
85
86 fn poll_next(
87 mut self: Pin<&mut Self>,
88 cx: &mut std::task::Context<'_>,
89 ) -> std::task::Poll<Option<Self::Item>> {
90 self.inner.poll_next_unpin(cx)
91 }
92}
93
94#[async_trait]
95impl AsyncTransport for WebsocketSecureTransport {
96 async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
97 self.inner.emit(data, is_binary_att).await
98 }
99
100 async fn base_url(&self) -> Result<Url> {
101 Ok(self.base_url.read().await.clone())
102 }
103
104 async fn set_base_url(&self, base_url: Url) -> Result<()> {
105 let mut url = base_url;
106 if !url
107 .query_pairs()
108 .any(|(k, v)| k == "transport" && v == "websocket")
109 {
110 url.query_pairs_mut().append_pair("transport", "websocket");
111 }
112 url.set_scheme("wss").unwrap();
113 *self.base_url.write().await = url;
114 Ok(())
115 }
116}
117
118impl Debug for WebsocketSecureTransport {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 f.debug_struct("AsyncWebsocketSecureTransport")
121 .field(
122 "base_url",
123 &self
124 .base_url
125 .try_read()
126 .map_or("Currently not available".to_owned(), |url| url.to_string()),
127 )
128 .finish()
129 }
130}
131
#[cfg(test)]
mod test {
    use super::*;
    use crate::ENGINE_IO_VERSION;
    use std::str::FromStr;

    /// Connects a transport to the secure engine.io test server using the
    /// test TLS connector and no extra headers.
    async fn new() -> Result<WebsocketSecureTransport> {
        let server = crate::test::engine_io_server_secure()?;
        let url = format!("{}engine.io/?EIO={}", server, ENGINE_IO_VERSION);
        let parsed = Url::from_str(&url)?;
        let connector = crate::test::tls_connector()?;
        WebsocketSecureTransport::new(parsed, Some(connector), None).await
    }

    /// The base URL is normalised on construction and on every replacement.
    #[tokio::test]
    async fn websocket_secure_transport_base_url() -> Result<()> {
        let transport = new().await?;

        // URL the freshly connected transport is expected to report.
        let mut initial = crate::test::engine_io_server_secure()?;
        initial.set_path("/engine.io/");
        initial
            .query_pairs_mut()
            .append_pair("EIO", &ENGINE_IO_VERSION.to_string())
            .append_pair("transport", "websocket");
        initial.set_scheme("wss").unwrap();
        let initial = initial.to_string();
        assert_eq!(transport.base_url().await?.to_string(), initial);

        // Both inputs must end up as the same normalised URL: one lacks the
        // transport pair, the other already has it.
        for replacement in [
            "https://127.0.0.1",
            "http://127.0.0.1/?transport=websocket",
        ] {
            transport
                .set_base_url(reqwest::Url::parse(replacement)?)
                .await?;
            let current = transport.base_url().await?.to_string();
            assert_eq!(current, "wss://127.0.0.1/?transport=websocket");
            assert_ne!(current, initial);
        }
        Ok(())
    }

    /// The `Debug` output consists of the struct label and the base URL.
    #[tokio::test]
    async fn websocket_secure_debug() -> Result<()> {
        let transport = new().await?;
        let rendered = format!("{:?}", transport);
        let url_text = transport.base_url().await?.to_string();
        assert_eq!(
            rendered,
            format!(
                "AsyncWebsocketSecureTransport {{ base_url: {:?} }}",
                url_text
            )
        );
        Ok(())
    }
}