Skip to main content

reqwest_websocket/
native.rs

1use std::borrow::Cow;
2
3use crate::{
4    protocol::{CloseCode, Message},
5    Client, Error, RequestBuilder,
6};
7use reqwest::{
8    header::{HeaderName, HeaderValue},
9    Response, StatusCode, Version,
10};
11use tungstenite::protocol::WebSocketConfig;
12
13pub async fn send_request<R>(
14    request_builder: R,
15    protocols: &[String],
16) -> Result<WebSocketResponse, Error>
17where
18    R: RequestBuilder,
19{
20    let (client, request_result) = request_builder.build_split();
21    let mut request = request_result?;
22
23    // change the scheme from wss? to https?
24    let url = request.url_mut();
25    match url.scheme() {
26        "ws" => {
27            url.set_scheme("http")
28                .expect("url should accept http scheme");
29        }
30        "wss" => {
31            url.set_scheme("https")
32                .expect("url should accept https scheme");
33        }
34        _ => {}
35    }
36
37    // prepare request
38    let version = request.version();
39    let nonce = match version {
40        Version::HTTP_10 | Version::HTTP_11 => {
41            // HTTP 1 requires us to set some headers.
42            let nonce_value = tungstenite::handshake::client::generate_key();
43            let headers = request.headers_mut();
44            headers.insert(
45                reqwest::header::CONNECTION,
46                HeaderValue::from_static("upgrade"),
47            );
48            headers.insert(
49                reqwest::header::UPGRADE,
50                HeaderValue::from_static("websocket"),
51            );
52            headers.insert(
53                reqwest::header::SEC_WEBSOCKET_KEY,
54                HeaderValue::from_str(&nonce_value).expect("nonce is a invalid header value"),
55            );
56            headers.insert(
57                reqwest::header::SEC_WEBSOCKET_VERSION,
58                HeaderValue::from_static("13"),
59            );
60            if !protocols.is_empty() {
61                headers.insert(
62                    reqwest::header::SEC_WEBSOCKET_PROTOCOL,
63                    HeaderValue::from_str(&protocols.join(", "))
64                        .expect("protocols is an invalid header value"),
65                );
66            }
67
68            Some(nonce_value)
69        }
70        Version::HTTP_2 => {
71            // TODO: Implement websocket upgrade for HTTP 2.
72            return Err(HandshakeError::UnsupportedHttpVersion(version).into());
73        }
74        _ => {
75            return Err(HandshakeError::UnsupportedHttpVersion(version).into());
76        }
77    };
78
79    // execute request
80    let response = client.execute(request).await?;
81
82    Ok(WebSocketResponse {
83        response,
84        version,
85        nonce,
86    })
87}
88
/// Websocket stream type used by this module: an `async_tungstenite` stream
/// layered over the upgraded `reqwest` connection, wrapped in `tokio_util`'s
/// `Compat` adapter.
pub type WebSocketStream =
    async_tungstenite::WebSocketStream<tokio_util::compat::Compat<reqwest::Upgraded>>;
91
/// Error during `Websocket` handshake.
///
/// Produced while preparing the upgrade request and while validating the
/// server's response to it.
#[derive(Debug, thiserror::Error)]
pub enum HandshakeError {
    /// The request uses an HTTP version for which the upgrade handshake is
    /// not implemented (anything other than HTTP/1.0 or HTTP/1.1).
    #[error("unsupported http version: {0:?}")]
    UnsupportedHttpVersion(Version),

    /// The HTTP version of the response differs from the version the request
    /// was built with.
    #[error("the server responded with a different http version. this could be the case because reqwest silently upgraded the connection to http2. see: https://github.com/jgraef/reqwest-websocket/issues/2")]
    ServerRespondedWithDifferentVersion,

    /// A header required by the handshake is absent from the response.
    #[error("missing header {header}")]
    MissingHeader { header: HeaderName },

    /// A handshake header is present but carries an unacceptable value.
    #[error("unexpected value for header {header}: expected {expected}, but got {got:?}.")]
    UnexpectedHeaderValue {
        /// Name of the offending header.
        header: HeaderName,
        /// Value the server actually sent.
        got: HeaderValue,
        /// Value the handshake expected.
        expected: Cow<'static, str>,
    },

    /// Protocols were requested, but the response did not contain a usable
    /// `Sec-WebSocket-Protocol` header.
    #[error("expected the server to select a protocol.")]
    ExpectedAProtocol,

    /// The server selected a protocol that was not requested (including the
    /// case where no protocol was requested at all).
    #[error("unexpected protocol: {got}")]
    UnexpectedProtocol { got: String },

    /// The server answered with a status other than `101 Switching Protocols`.
    #[error("unexpected status code: {0}")]
    UnexpectedStatusCode(StatusCode),
}
120
/// Server response to a websocket handshake request, together with the
/// request-side data needed to validate it.
pub struct WebSocketResponse {
    /// The raw HTTP response returned by the server.
    pub response: Response,
    /// HTTP version the handshake request was built with.
    pub version: Version,
    /// The key sent in the `Sec-WebSocket-Key` request header, if one was
    /// sent.
    pub nonce: Option<String>,
}
126
127impl WebSocketResponse {
128    pub async fn into_stream_and_protocol(
129        self,
130        protocols: Vec<String>,
131        web_socket_config: Option<WebSocketConfig>,
132    ) -> Result<(WebSocketStream, Option<String>), Error> {
133        let headers = self.response.headers();
134
135        if self.response.version() != self.version {
136            return Err(HandshakeError::ServerRespondedWithDifferentVersion.into());
137        }
138
139        if self.response.status() != reqwest::StatusCode::SWITCHING_PROTOCOLS {
140            tracing::debug!(status_code = %self.response.status(), "server responded with unexpected status code");
141            return Err(HandshakeError::UnexpectedStatusCode(self.response.status()).into());
142        }
143
144        if let Some(header) = headers.get(reqwest::header::CONNECTION) {
145            if !header
146                .to_str()
147                .is_ok_and(|s| s.eq_ignore_ascii_case("upgrade"))
148            {
149                tracing::debug!("server responded with invalid Connection header: {header:?}");
150                return Err(HandshakeError::UnexpectedHeaderValue {
151                    header: reqwest::header::CONNECTION,
152                    got: header.clone(),
153                    expected: "upgrade".into(),
154                }
155                .into());
156            }
157        } else {
158            tracing::debug!("missing Connection header");
159            return Err(HandshakeError::MissingHeader {
160                header: reqwest::header::CONNECTION,
161            }
162            .into());
163        }
164
165        if let Some(header) = headers.get(reqwest::header::UPGRADE) {
166            if !header
167                .to_str()
168                .is_ok_and(|s| s.eq_ignore_ascii_case("websocket"))
169            {
170                tracing::debug!("server responded with invalid Upgrade header: {header:?}");
171                return Err(HandshakeError::UnexpectedHeaderValue {
172                    header: reqwest::header::UPGRADE,
173                    got: header.clone(),
174                    expected: "websocket".into(),
175                }
176                .into());
177            }
178        } else {
179            tracing::debug!("missing Upgrade header");
180            return Err(HandshakeError::MissingHeader {
181                header: reqwest::header::UPGRADE,
182            }
183            .into());
184        }
185
186        if let Some(nonce) = &self.nonce {
187            let expected_nonce = tungstenite::handshake::derive_accept_key(nonce.as_bytes());
188
189            if let Some(header) = headers.get(reqwest::header::SEC_WEBSOCKET_ACCEPT) {
190                if !header.to_str().is_ok_and(|s| s == expected_nonce) {
191                    tracing::debug!(
192                        "server responded with invalid Sec-Websocket-Accept header: {header:?}"
193                    );
194                    return Err(HandshakeError::UnexpectedHeaderValue {
195                        header: reqwest::header::SEC_WEBSOCKET_ACCEPT,
196                        got: header.clone(),
197                        expected: expected_nonce.into(),
198                    }
199                    .into());
200                }
201            } else {
202                tracing::debug!("missing Sec-Websocket-Accept header");
203                return Err(HandshakeError::MissingHeader {
204                    header: reqwest::header::SEC_WEBSOCKET_ACCEPT,
205                }
206                .into());
207            }
208        }
209
210        let protocol = headers
211            .get(reqwest::header::SEC_WEBSOCKET_PROTOCOL)
212            .and_then(|v| v.to_str().ok())
213            .map(ToOwned::to_owned);
214
215        match (protocols.is_empty(), &protocol) {
216            (true, None) => {
217                // we didn't request any protocols, so we don't expect one
218                // in return
219            }
220            (false, None) => {
221                // server didn't reply with a protocol
222                return Err(HandshakeError::ExpectedAProtocol.into());
223            }
224            (false, Some(protocol)) => {
225                if !protocols.contains(protocol) {
226                    // the responded protocol is none which we requested
227                    return Err(HandshakeError::UnexpectedProtocol {
228                        got: protocol.clone(),
229                    }
230                    .into());
231                }
232            }
233            (true, Some(protocol)) => {
234                // we didn't request any protocols but got one anyway
235                return Err(HandshakeError::UnexpectedProtocol {
236                    got: protocol.clone(),
237                }
238                .into());
239            }
240        }
241
242        use tokio_util::compat::TokioAsyncReadCompatExt;
243
244        let inner = WebSocketStream::from_raw_socket(
245            self.response.upgrade().await?.compat(),
246            tungstenite::protocol::Role::Client,
247            web_socket_config,
248        )
249        .await;
250
251        Ok((inner, protocol))
252    }
253}
254
/// Error returned when a `tungstenite::Message` cannot be converted into a
/// [`Message`].
#[derive(Debug, thiserror::Error)]
#[error("could not convert message")]
pub struct FromTungsteniteMessageError {
    /// The message that could not be converted, handed back unchanged.
    pub original: tungstenite::Message,
}
260
261impl TryFrom<tungstenite::Message> for Message {
262    type Error = FromTungsteniteMessageError;
263
264    fn try_from(value: tungstenite::Message) -> Result<Self, Self::Error> {
265        match value {
266            tungstenite::Message::Text(text) => Ok(Self::Text(text.as_str().to_owned())),
267            tungstenite::Message::Binary(data) => Ok(Self::Binary(data)),
268            tungstenite::Message::Ping(data) => Ok(Self::Ping(data)),
269            tungstenite::Message::Pong(data) => Ok(Self::Pong(data)),
270            tungstenite::Message::Close(Some(tungstenite::protocol::CloseFrame {
271                code,
272                reason,
273            })) => Ok(Self::Close {
274                code: code.into(),
275                reason: reason.as_str().to_owned(),
276            }),
277            tungstenite::Message::Close(None) => Ok(Self::Close {
278                code: CloseCode::default(),
279                reason: "".to_owned(),
280            }),
281            tungstenite::Message::Frame(_) => Err(FromTungsteniteMessageError { original: value }),
282        }
283    }
284}
285
286impl From<Message> for tungstenite::Message {
287    fn from(value: Message) -> Self {
288        match value {
289            Message::Text(text) => Self::Text(tungstenite::Utf8Bytes::from(text)),
290            Message::Binary(data) => Self::Binary(data),
291            Message::Ping(data) => Self::Ping(data),
292            Message::Pong(data) => Self::Pong(data),
293            Message::Close { code, reason } => {
294                Self::Close(Some(tungstenite::protocol::CloseFrame {
295                    code: code.into(),
296                    reason: reason.into(),
297                }))
298            }
299        }
300    }
301}
302
303impl From<tungstenite::protocol::frame::coding::CloseCode> for CloseCode {
304    fn from(value: tungstenite::protocol::frame::coding::CloseCode) -> Self {
305        u16::from(value).into()
306    }
307}
308
309impl From<CloseCode> for tungstenite::protocol::frame::coding::CloseCode {
310    fn from(value: CloseCode) -> Self {
311        u16::from(value).into()
312    }
313}