// reqwest_websocket/native.rs

1use std::borrow::Cow;
2
3use crate::{
4    protocol::{CloseCode, Message},
5    Error,
6};
7use reqwest::{
8    header::{HeaderName, HeaderValue},
9    RequestBuilder, Response, StatusCode, Version,
10};
11use tungstenite::protocol::WebSocketConfig;
12
13pub async fn send_request(
14    request_builder: RequestBuilder,
15    protocols: &[String],
16) -> Result<WebSocketResponse, Error> {
17    let (client, request_result) = request_builder.build_split();
18    let mut request = request_result?;
19
20    // change the scheme from wss? to https?
21    let url = request.url_mut();
22    match url.scheme() {
23        "ws" => {
24            url.set_scheme("http")
25                .expect("url should accept http scheme");
26        }
27        "wss" => {
28            url.set_scheme("https")
29                .expect("url should accept https scheme");
30        }
31        _ => {}
32    }
33
34    // prepare request
35    let version = request.version();
36    let nonce = match version {
37        Version::HTTP_10 | Version::HTTP_11 => {
38            // HTTP 1 requires us to set some headers.
39            let nonce_value = tungstenite::handshake::client::generate_key();
40            let headers = request.headers_mut();
41            headers.insert(
42                reqwest::header::CONNECTION,
43                HeaderValue::from_static("upgrade"),
44            );
45            headers.insert(
46                reqwest::header::UPGRADE,
47                HeaderValue::from_static("websocket"),
48            );
49            headers.insert(
50                reqwest::header::SEC_WEBSOCKET_KEY,
51                HeaderValue::from_str(&nonce_value).expect("nonce is a invalid header value"),
52            );
53            headers.insert(
54                reqwest::header::SEC_WEBSOCKET_VERSION,
55                HeaderValue::from_static("13"),
56            );
57            if !protocols.is_empty() {
58                headers.insert(
59                    reqwest::header::SEC_WEBSOCKET_PROTOCOL,
60                    HeaderValue::from_str(&protocols.join(", "))
61                        .expect("protocols is an invalid header value"),
62                );
63            }
64
65            Some(nonce_value)
66        }
67        Version::HTTP_2 => {
68            // TODO: Implement websocket upgrade for HTTP 2.
69            return Err(HandshakeError::UnsupportedHttpVersion(version).into());
70        }
71        _ => {
72            return Err(HandshakeError::UnsupportedHttpVersion(version).into());
73        }
74    };
75
76    // execute request
77    let response = client.execute(request).await?;
78
79    Ok(WebSocketResponse {
80        response,
81        version,
82        nonce,
83    })
84}
85
86pub type WebSocketStream =
87    async_tungstenite::WebSocketStream<tokio_util::compat::Compat<reqwest::Upgraded>>;
88
89/// Error during `Websocket` handshake.
90#[derive(Debug, thiserror::Error)]
91pub enum HandshakeError {
92    #[error("unsupported http version: {0:?}")]
93    UnsupportedHttpVersion(Version),
94
95    #[error("the server responded with a different http version. this could be the case because reqwest silently upgraded the connection to http2. see: https://github.com/jgraef/reqwest-websocket/issues/2")]
96    ServerRespondedWithDifferentVersion,
97
98    #[error("missing header {header}")]
99    MissingHeader { header: HeaderName },
100
101    #[error("unexpected value for header {header}: expected {expected}, but got {got:?}.")]
102    UnexpectedHeaderValue {
103        header: HeaderName,
104        got: HeaderValue,
105        expected: Cow<'static, str>,
106    },
107
108    #[error("expected the server to select a protocol.")]
109    ExpectedAProtocol,
110
111    #[error("unexpected protocol: {got}")]
112    UnexpectedProtocol { got: String },
113
114    #[error("unexpected status code: {0}")]
115    UnexpectedStatusCode(StatusCode),
116}
117
118pub struct WebSocketResponse {
119    pub response: Response,
120    pub version: Version,
121    pub nonce: Option<String>,
122}
123
124impl WebSocketResponse {
125    pub async fn into_stream_and_protocol(
126        self,
127        protocols: Vec<String>,
128        web_socket_config: Option<WebSocketConfig>,
129    ) -> Result<(WebSocketStream, Option<String>), Error> {
130        let headers = self.response.headers();
131
132        if self.response.version() != self.version {
133            return Err(HandshakeError::ServerRespondedWithDifferentVersion.into());
134        }
135
136        if self.response.status() != reqwest::StatusCode::SWITCHING_PROTOCOLS {
137            tracing::debug!(status_code = %self.response.status(), "server responded with unexpected status code");
138            return Err(HandshakeError::UnexpectedStatusCode(self.response.status()).into());
139        }
140
141        if let Some(header) = headers.get(reqwest::header::CONNECTION) {
142            if !header
143                .to_str()
144                .is_ok_and(|s| s.eq_ignore_ascii_case("upgrade"))
145            {
146                tracing::debug!("server responded with invalid Connection header: {header:?}");
147                return Err(HandshakeError::UnexpectedHeaderValue {
148                    header: reqwest::header::CONNECTION,
149                    got: header.clone(),
150                    expected: "upgrade".into(),
151                }
152                .into());
153            }
154        } else {
155            tracing::debug!("missing Connection header");
156            return Err(HandshakeError::MissingHeader {
157                header: reqwest::header::CONNECTION,
158            }
159            .into());
160        }
161
162        if let Some(header) = headers.get(reqwest::header::UPGRADE) {
163            if !header
164                .to_str()
165                .is_ok_and(|s| s.eq_ignore_ascii_case("websocket"))
166            {
167                tracing::debug!("server responded with invalid Upgrade header: {header:?}");
168                return Err(HandshakeError::UnexpectedHeaderValue {
169                    header: reqwest::header::UPGRADE,
170                    got: header.clone(),
171                    expected: "websocket".into(),
172                }
173                .into());
174            }
175        } else {
176            tracing::debug!("missing Upgrade header");
177            return Err(HandshakeError::MissingHeader {
178                header: reqwest::header::UPGRADE,
179            }
180            .into());
181        }
182
183        if let Some(nonce) = &self.nonce {
184            let expected_nonce = tungstenite::handshake::derive_accept_key(nonce.as_bytes());
185
186            if let Some(header) = headers.get(reqwest::header::SEC_WEBSOCKET_ACCEPT) {
187                if !header.to_str().is_ok_and(|s| s == expected_nonce) {
188                    tracing::debug!(
189                        "server responded with invalid Sec-Websocket-Accept header: {header:?}"
190                    );
191                    return Err(HandshakeError::UnexpectedHeaderValue {
192                        header: reqwest::header::SEC_WEBSOCKET_ACCEPT,
193                        got: header.clone(),
194                        expected: expected_nonce.into(),
195                    }
196                    .into());
197                }
198            } else {
199                tracing::debug!("missing Sec-Websocket-Accept header");
200                return Err(HandshakeError::MissingHeader {
201                    header: reqwest::header::SEC_WEBSOCKET_ACCEPT,
202                }
203                .into());
204            }
205        }
206
207        let protocol = headers
208            .get(reqwest::header::SEC_WEBSOCKET_PROTOCOL)
209            .and_then(|v| v.to_str().ok())
210            .map(ToOwned::to_owned);
211
212        match (protocols.is_empty(), &protocol) {
213            (true, None) => {
214                // we didn't request any protocols, so we don't expect one
215                // in return
216            }
217            (false, None) => {
218                // server didn't reply with a protocol
219                return Err(HandshakeError::ExpectedAProtocol.into());
220            }
221            (false, Some(protocol)) => {
222                if !protocols.contains(protocol) {
223                    // the responded protocol is none which we requested
224                    return Err(HandshakeError::UnexpectedProtocol {
225                        got: protocol.clone(),
226                    }
227                    .into());
228                }
229            }
230            (true, Some(protocol)) => {
231                // we didn't request any protocols but got one anyway
232                return Err(HandshakeError::UnexpectedProtocol {
233                    got: protocol.clone(),
234                }
235                .into());
236            }
237        }
238
239        use tokio_util::compat::TokioAsyncReadCompatExt;
240
241        let inner = WebSocketStream::from_raw_socket(
242            self.response.upgrade().await?.compat(),
243            tungstenite::protocol::Role::Client,
244            web_socket_config,
245        )
246        .await;
247
248        Ok((inner, protocol))
249    }
250}
251
252#[derive(Debug, thiserror::Error)]
253#[error("could not convert message")]
254pub struct FromTungsteniteMessageError {
255    pub original: tungstenite::Message,
256}
257
258impl TryFrom<tungstenite::Message> for Message {
259    type Error = FromTungsteniteMessageError;
260
261    fn try_from(value: tungstenite::Message) -> Result<Self, Self::Error> {
262        match value {
263            tungstenite::Message::Text(text) => Ok(Self::Text(text.as_str().to_owned())),
264            tungstenite::Message::Binary(data) => Ok(Self::Binary(data)),
265            tungstenite::Message::Ping(data) => Ok(Self::Ping(data)),
266            tungstenite::Message::Pong(data) => Ok(Self::Pong(data)),
267            tungstenite::Message::Close(Some(tungstenite::protocol::CloseFrame {
268                code,
269                reason,
270            })) => Ok(Self::Close {
271                code: code.into(),
272                reason: reason.as_str().to_owned(),
273            }),
274            tungstenite::Message::Close(None) => Ok(Self::Close {
275                code: CloseCode::default(),
276                reason: "".to_owned(),
277            }),
278            tungstenite::Message::Frame(_) => Err(FromTungsteniteMessageError { original: value }),
279        }
280    }
281}
282
283impl From<Message> for tungstenite::Message {
284    fn from(value: Message) -> Self {
285        match value {
286            Message::Text(text) => Self::Text(tungstenite::Utf8Bytes::from(text)),
287            Message::Binary(data) => Self::Binary(data),
288            Message::Ping(data) => Self::Ping(data),
289            Message::Pong(data) => Self::Pong(data),
290            Message::Close { code, reason } => {
291                Self::Close(Some(tungstenite::protocol::CloseFrame {
292                    code: code.into(),
293                    reason: reason.into(),
294                }))
295            }
296        }
297    }
298}
299
300impl From<tungstenite::protocol::frame::coding::CloseCode> for CloseCode {
301    fn from(value: tungstenite::protocol::frame::coding::CloseCode) -> Self {
302        u16::from(value).into()
303    }
304}
305
306impl From<CloseCode> for tungstenite::protocol::frame::coding::CloseCode {
307    fn from(value: CloseCode) -> Self {
308        u16::from(value).into()
309    }
310}