1use std::borrow::Cow;
2
3use crate::{
4 protocol::{CloseCode, Message},
5 Client, Error, RequestBuilder,
6};
7use reqwest::{
8 header::{HeaderName, HeaderValue},
9 Response, StatusCode, Version,
10};
11use tungstenite::protocol::WebSocketConfig;
12
13pub async fn send_request<R>(
14 request_builder: R,
15 protocols: &[String],
16) -> Result<WebSocketResponse, Error>
17where
18 R: RequestBuilder,
19{
20 let (client, request_result) = request_builder.build_split();
21 let mut request = request_result?;
22
23 let url = request.url_mut();
25 match url.scheme() {
26 "ws" => {
27 url.set_scheme("http")
28 .expect("url should accept http scheme");
29 }
30 "wss" => {
31 url.set_scheme("https")
32 .expect("url should accept https scheme");
33 }
34 _ => {}
35 }
36
37 let version = request.version();
39 let nonce = match version {
40 Version::HTTP_10 | Version::HTTP_11 => {
41 let nonce_value = tungstenite::handshake::client::generate_key();
43 let headers = request.headers_mut();
44 headers.insert(
45 reqwest::header::CONNECTION,
46 HeaderValue::from_static("upgrade"),
47 );
48 headers.insert(
49 reqwest::header::UPGRADE,
50 HeaderValue::from_static("websocket"),
51 );
52 headers.insert(
53 reqwest::header::SEC_WEBSOCKET_KEY,
54 HeaderValue::from_str(&nonce_value).expect("nonce is a invalid header value"),
55 );
56 headers.insert(
57 reqwest::header::SEC_WEBSOCKET_VERSION,
58 HeaderValue::from_static("13"),
59 );
60 if !protocols.is_empty() {
61 headers.insert(
62 reqwest::header::SEC_WEBSOCKET_PROTOCOL,
63 HeaderValue::from_str(&protocols.join(", "))
64 .expect("protocols is an invalid header value"),
65 );
66 }
67
68 Some(nonce_value)
69 }
70 Version::HTTP_2 => {
71 return Err(HandshakeError::UnsupportedHttpVersion(version).into());
73 }
74 _ => {
75 return Err(HandshakeError::UnsupportedHttpVersion(version).into());
76 }
77 };
78
79 let response = client.execute(request).await?;
81
82 Ok(WebSocketResponse {
83 response,
84 version,
85 nonce,
86 })
87}
88
89pub type WebSocketStream =
90 async_tungstenite::WebSocketStream<tokio_util::compat::Compat<reqwest::Upgraded>>;
91
92#[derive(Debug, thiserror::Error)]
94pub enum HandshakeError {
95 #[error("unsupported http version: {0:?}")]
96 UnsupportedHttpVersion(Version),
97
98 #[error("the server responded with a different http version. this could be the case because reqwest silently upgraded the connection to http2. see: https://github.com/jgraef/reqwest-websocket/issues/2")]
99 ServerRespondedWithDifferentVersion,
100
101 #[error("missing header {header}")]
102 MissingHeader { header: HeaderName },
103
104 #[error("unexpected value for header {header}: expected {expected}, but got {got:?}.")]
105 UnexpectedHeaderValue {
106 header: HeaderName,
107 got: HeaderValue,
108 expected: Cow<'static, str>,
109 },
110
111 #[error("expected the server to select a protocol.")]
112 ExpectedAProtocol,
113
114 #[error("unexpected protocol: {got}")]
115 UnexpectedProtocol { got: String },
116
117 #[error("unexpected status code: {0}")]
118 UnexpectedStatusCode(StatusCode),
119}
120
121pub struct WebSocketResponse {
122 pub response: Response,
123 pub version: Version,
124 pub nonce: Option<String>,
125}
126
127impl WebSocketResponse {
128 pub async fn into_stream_and_protocol(
129 self,
130 protocols: Vec<String>,
131 web_socket_config: Option<WebSocketConfig>,
132 ) -> Result<(WebSocketStream, Option<String>), Error> {
133 let headers = self.response.headers();
134
135 if self.response.version() != self.version {
136 return Err(HandshakeError::ServerRespondedWithDifferentVersion.into());
137 }
138
139 if self.response.status() != reqwest::StatusCode::SWITCHING_PROTOCOLS {
140 tracing::debug!(status_code = %self.response.status(), "server responded with unexpected status code");
141 return Err(HandshakeError::UnexpectedStatusCode(self.response.status()).into());
142 }
143
144 if let Some(header) = headers.get(reqwest::header::CONNECTION) {
145 if !header
146 .to_str()
147 .is_ok_and(|s| s.eq_ignore_ascii_case("upgrade"))
148 {
149 tracing::debug!("server responded with invalid Connection header: {header:?}");
150 return Err(HandshakeError::UnexpectedHeaderValue {
151 header: reqwest::header::CONNECTION,
152 got: header.clone(),
153 expected: "upgrade".into(),
154 }
155 .into());
156 }
157 } else {
158 tracing::debug!("missing Connection header");
159 return Err(HandshakeError::MissingHeader {
160 header: reqwest::header::CONNECTION,
161 }
162 .into());
163 }
164
165 if let Some(header) = headers.get(reqwest::header::UPGRADE) {
166 if !header
167 .to_str()
168 .is_ok_and(|s| s.eq_ignore_ascii_case("websocket"))
169 {
170 tracing::debug!("server responded with invalid Upgrade header: {header:?}");
171 return Err(HandshakeError::UnexpectedHeaderValue {
172 header: reqwest::header::UPGRADE,
173 got: header.clone(),
174 expected: "websocket".into(),
175 }
176 .into());
177 }
178 } else {
179 tracing::debug!("missing Upgrade header");
180 return Err(HandshakeError::MissingHeader {
181 header: reqwest::header::UPGRADE,
182 }
183 .into());
184 }
185
186 if let Some(nonce) = &self.nonce {
187 let expected_nonce = tungstenite::handshake::derive_accept_key(nonce.as_bytes());
188
189 if let Some(header) = headers.get(reqwest::header::SEC_WEBSOCKET_ACCEPT) {
190 if !header.to_str().is_ok_and(|s| s == expected_nonce) {
191 tracing::debug!(
192 "server responded with invalid Sec-Websocket-Accept header: {header:?}"
193 );
194 return Err(HandshakeError::UnexpectedHeaderValue {
195 header: reqwest::header::SEC_WEBSOCKET_ACCEPT,
196 got: header.clone(),
197 expected: expected_nonce.into(),
198 }
199 .into());
200 }
201 } else {
202 tracing::debug!("missing Sec-Websocket-Accept header");
203 return Err(HandshakeError::MissingHeader {
204 header: reqwest::header::SEC_WEBSOCKET_ACCEPT,
205 }
206 .into());
207 }
208 }
209
210 let protocol = headers
211 .get(reqwest::header::SEC_WEBSOCKET_PROTOCOL)
212 .and_then(|v| v.to_str().ok())
213 .map(ToOwned::to_owned);
214
215 match (protocols.is_empty(), &protocol) {
216 (true, None) => {
217 }
220 (false, None) => {
221 return Err(HandshakeError::ExpectedAProtocol.into());
223 }
224 (false, Some(protocol)) => {
225 if !protocols.contains(protocol) {
226 return Err(HandshakeError::UnexpectedProtocol {
228 got: protocol.clone(),
229 }
230 .into());
231 }
232 }
233 (true, Some(protocol)) => {
234 return Err(HandshakeError::UnexpectedProtocol {
236 got: protocol.clone(),
237 }
238 .into());
239 }
240 }
241
242 use tokio_util::compat::TokioAsyncReadCompatExt;
243
244 let inner = WebSocketStream::from_raw_socket(
245 self.response.upgrade().await?.compat(),
246 tungstenite::protocol::Role::Client,
247 web_socket_config,
248 )
249 .await;
250
251 Ok((inner, protocol))
252 }
253}
254
255#[derive(Debug, thiserror::Error)]
256#[error("could not convert message")]
257pub struct FromTungsteniteMessageError {
258 pub original: tungstenite::Message,
259}
260
261impl TryFrom<tungstenite::Message> for Message {
262 type Error = FromTungsteniteMessageError;
263
264 fn try_from(value: tungstenite::Message) -> Result<Self, Self::Error> {
265 match value {
266 tungstenite::Message::Text(text) => Ok(Self::Text(text.as_str().to_owned())),
267 tungstenite::Message::Binary(data) => Ok(Self::Binary(data)),
268 tungstenite::Message::Ping(data) => Ok(Self::Ping(data)),
269 tungstenite::Message::Pong(data) => Ok(Self::Pong(data)),
270 tungstenite::Message::Close(Some(tungstenite::protocol::CloseFrame {
271 code,
272 reason,
273 })) => Ok(Self::Close {
274 code: code.into(),
275 reason: reason.as_str().to_owned(),
276 }),
277 tungstenite::Message::Close(None) => Ok(Self::Close {
278 code: CloseCode::default(),
279 reason: "".to_owned(),
280 }),
281 tungstenite::Message::Frame(_) => Err(FromTungsteniteMessageError { original: value }),
282 }
283 }
284}
285
286impl From<Message> for tungstenite::Message {
287 fn from(value: Message) -> Self {
288 match value {
289 Message::Text(text) => Self::Text(tungstenite::Utf8Bytes::from(text)),
290 Message::Binary(data) => Self::Binary(data),
291 Message::Ping(data) => Self::Ping(data),
292 Message::Pong(data) => Self::Pong(data),
293 Message::Close { code, reason } => {
294 Self::Close(Some(tungstenite::protocol::CloseFrame {
295 code: code.into(),
296 reason: reason.into(),
297 }))
298 }
299 }
300 }
301}
302
303impl From<tungstenite::protocol::frame::coding::CloseCode> for CloseCode {
304 fn from(value: tungstenite::protocol::frame::coding::CloseCode) -> Self {
305 u16::from(value).into()
306 }
307}
308
309impl From<CloseCode> for tungstenite::protocol::frame::coding::CloseCode {
310 fn from(value: CloseCode) -> Self {
311 u16::from(value).into()
312 }
313}