1use std::borrow::Cow;
2
3use crate::{
4 protocol::{CloseCode, Message},
5 Error,
6};
7use reqwest::{
8 header::{HeaderName, HeaderValue},
9 RequestBuilder, Response, StatusCode, Version,
10};
11use tungstenite::protocol::WebSocketConfig;
12
13pub async fn send_request(
14 request_builder: RequestBuilder,
15 protocols: &[String],
16) -> Result<WebSocketResponse, Error> {
17 let (client, request_result) = request_builder.build_split();
18 let mut request = request_result?;
19
20 let url = request.url_mut();
22 match url.scheme() {
23 "ws" => {
24 url.set_scheme("http")
25 .expect("url should accept http scheme");
26 }
27 "wss" => {
28 url.set_scheme("https")
29 .expect("url should accept https scheme");
30 }
31 _ => {}
32 }
33
34 let version = request.version();
36 let nonce = match version {
37 Version::HTTP_10 | Version::HTTP_11 => {
38 let nonce_value = tungstenite::handshake::client::generate_key();
40 let headers = request.headers_mut();
41 headers.insert(
42 reqwest::header::CONNECTION,
43 HeaderValue::from_static("upgrade"),
44 );
45 headers.insert(
46 reqwest::header::UPGRADE,
47 HeaderValue::from_static("websocket"),
48 );
49 headers.insert(
50 reqwest::header::SEC_WEBSOCKET_KEY,
51 HeaderValue::from_str(&nonce_value).expect("nonce is a invalid header value"),
52 );
53 headers.insert(
54 reqwest::header::SEC_WEBSOCKET_VERSION,
55 HeaderValue::from_static("13"),
56 );
57 if !protocols.is_empty() {
58 headers.insert(
59 reqwest::header::SEC_WEBSOCKET_PROTOCOL,
60 HeaderValue::from_str(&protocols.join(", "))
61 .expect("protocols is an invalid header value"),
62 );
63 }
64
65 Some(nonce_value)
66 }
67 Version::HTTP_2 => {
68 return Err(HandshakeError::UnsupportedHttpVersion(version).into());
70 }
71 _ => {
72 return Err(HandshakeError::UnsupportedHttpVersion(version).into());
73 }
74 };
75
76 let response = client.execute(request).await?;
78
79 Ok(WebSocketResponse {
80 response,
81 version,
82 nonce,
83 })
84}
85
/// The WebSocket stream produced by [`WebSocketResponse::into_stream_and_protocol`].
///
/// It is an `async_tungstenite` stream running over the connection returned by reqwest's
/// `upgrade()`, wrapped in `tokio_util::compat::Compat`.
pub type WebSocketStream =
    async_tungstenite::WebSocketStream<tokio_util::compat::Compat<reqwest::Upgraded>>;
88
/// Reasons the WebSocket opening handshake can fail.
#[derive(Debug, thiserror::Error)]
pub enum HandshakeError {
    /// The request uses an HTTP version other than HTTP/1.0 or HTTP/1.1, for which no upgrade
    /// mechanism is implemented (returned by `send_request`).
    #[error("unsupported http version: {0:?}")]
    UnsupportedHttpVersion(Version),

    /// The response's HTTP version differs from the version the request was sent with.
    #[error("the server responded with a different http version. this could be the case because reqwest silently upgraded the connection to http2. see: https://github.com/jgraef/reqwest-websocket/issues/2")]
    ServerRespondedWithDifferentVersion,

    /// A header required by the handshake is absent from the response.
    #[error("missing header {header}")]
    MissingHeader { header: HeaderName },

    /// A handshake header is present but its value is not acceptable.
    #[error("unexpected value for header {header}: expected {expected}, but got {got:?}.")]
    UnexpectedHeaderValue {
        /// The header whose value was rejected.
        header: HeaderName,
        /// The value the server actually sent.
        got: HeaderValue,
        /// A description of the value that was expected.
        expected: Cow<'static, str>,
    },

    /// Subprotocols were requested, but the response named none in `Sec-WebSocket-Protocol`.
    /// This also covers a value that could not be read as a string.
    #[error("expected the server to select a protocol.")]
    ExpectedAProtocol,

    /// The server selected a subprotocol the client did not request. This includes selecting
    /// any protocol when none was requested.
    #[error("unexpected protocol: {got}")]
    UnexpectedProtocol { got: String },

    /// The server answered with a status other than `101 Switching Protocols`.
    #[error("unexpected status code: {0}")]
    UnexpectedStatusCode(StatusCode),
}
117
118pub struct WebSocketResponse {
119 pub response: Response,
120 pub version: Version,
121 pub nonce: Option<String>,
122}
123
124impl WebSocketResponse {
125 pub async fn into_stream_and_protocol(
126 self,
127 protocols: Vec<String>,
128 web_socket_config: Option<WebSocketConfig>,
129 ) -> Result<(WebSocketStream, Option<String>), Error> {
130 let headers = self.response.headers();
131
132 if self.response.version() != self.version {
133 return Err(HandshakeError::ServerRespondedWithDifferentVersion.into());
134 }
135
136 if self.response.status() != reqwest::StatusCode::SWITCHING_PROTOCOLS {
137 tracing::debug!(status_code = %self.response.status(), "server responded with unexpected status code");
138 return Err(HandshakeError::UnexpectedStatusCode(self.response.status()).into());
139 }
140
141 if let Some(header) = headers.get(reqwest::header::CONNECTION) {
142 if !header
143 .to_str()
144 .is_ok_and(|s| s.eq_ignore_ascii_case("upgrade"))
145 {
146 tracing::debug!("server responded with invalid Connection header: {header:?}");
147 return Err(HandshakeError::UnexpectedHeaderValue {
148 header: reqwest::header::CONNECTION,
149 got: header.clone(),
150 expected: "upgrade".into(),
151 }
152 .into());
153 }
154 } else {
155 tracing::debug!("missing Connection header");
156 return Err(HandshakeError::MissingHeader {
157 header: reqwest::header::CONNECTION,
158 }
159 .into());
160 }
161
162 if let Some(header) = headers.get(reqwest::header::UPGRADE) {
163 if !header
164 .to_str()
165 .is_ok_and(|s| s.eq_ignore_ascii_case("websocket"))
166 {
167 tracing::debug!("server responded with invalid Upgrade header: {header:?}");
168 return Err(HandshakeError::UnexpectedHeaderValue {
169 header: reqwest::header::UPGRADE,
170 got: header.clone(),
171 expected: "websocket".into(),
172 }
173 .into());
174 }
175 } else {
176 tracing::debug!("missing Upgrade header");
177 return Err(HandshakeError::MissingHeader {
178 header: reqwest::header::UPGRADE,
179 }
180 .into());
181 }
182
183 if let Some(nonce) = &self.nonce {
184 let expected_nonce = tungstenite::handshake::derive_accept_key(nonce.as_bytes());
185
186 if let Some(header) = headers.get(reqwest::header::SEC_WEBSOCKET_ACCEPT) {
187 if !header.to_str().is_ok_and(|s| s == expected_nonce) {
188 tracing::debug!(
189 "server responded with invalid Sec-Websocket-Accept header: {header:?}"
190 );
191 return Err(HandshakeError::UnexpectedHeaderValue {
192 header: reqwest::header::SEC_WEBSOCKET_ACCEPT,
193 got: header.clone(),
194 expected: expected_nonce.into(),
195 }
196 .into());
197 }
198 } else {
199 tracing::debug!("missing Sec-Websocket-Accept header");
200 return Err(HandshakeError::MissingHeader {
201 header: reqwest::header::SEC_WEBSOCKET_ACCEPT,
202 }
203 .into());
204 }
205 }
206
207 let protocol = headers
208 .get(reqwest::header::SEC_WEBSOCKET_PROTOCOL)
209 .and_then(|v| v.to_str().ok())
210 .map(ToOwned::to_owned);
211
212 match (protocols.is_empty(), &protocol) {
213 (true, None) => {
214 }
217 (false, None) => {
218 return Err(HandshakeError::ExpectedAProtocol.into());
220 }
221 (false, Some(protocol)) => {
222 if !protocols.contains(protocol) {
223 return Err(HandshakeError::UnexpectedProtocol {
225 got: protocol.clone(),
226 }
227 .into());
228 }
229 }
230 (true, Some(protocol)) => {
231 return Err(HandshakeError::UnexpectedProtocol {
233 got: protocol.clone(),
234 }
235 .into());
236 }
237 }
238
239 use tokio_util::compat::TokioAsyncReadCompatExt;
240
241 let inner = WebSocketStream::from_raw_socket(
242 self.response.upgrade().await?.compat(),
243 tungstenite::protocol::Role::Client,
244 web_socket_config,
245 )
246 .await;
247
248 Ok((inner, protocol))
249 }
250}
251
/// Error returned by `TryFrom<tungstenite::Message> for Message` when a tungstenite message has
/// no equivalent [`Message`] variant (raw `Frame` messages).
///
/// The unconverted message is handed back in `original`.
#[derive(Debug, thiserror::Error)]
#[error("could not convert message")]
pub struct FromTungsteniteMessageError {
    /// The message that could not be converted.
    pub original: tungstenite::Message,
}
257
258impl TryFrom<tungstenite::Message> for Message {
259 type Error = FromTungsteniteMessageError;
260
261 fn try_from(value: tungstenite::Message) -> Result<Self, Self::Error> {
262 match value {
263 tungstenite::Message::Text(text) => Ok(Self::Text(text.as_str().to_owned())),
264 tungstenite::Message::Binary(data) => Ok(Self::Binary(data)),
265 tungstenite::Message::Ping(data) => Ok(Self::Ping(data)),
266 tungstenite::Message::Pong(data) => Ok(Self::Pong(data)),
267 tungstenite::Message::Close(Some(tungstenite::protocol::CloseFrame {
268 code,
269 reason,
270 })) => Ok(Self::Close {
271 code: code.into(),
272 reason: reason.as_str().to_owned(),
273 }),
274 tungstenite::Message::Close(None) => Ok(Self::Close {
275 code: CloseCode::default(),
276 reason: "".to_owned(),
277 }),
278 tungstenite::Message::Frame(_) => Err(FromTungsteniteMessageError { original: value }),
279 }
280 }
281}
282
283impl From<Message> for tungstenite::Message {
284 fn from(value: Message) -> Self {
285 match value {
286 Message::Text(text) => Self::Text(tungstenite::Utf8Bytes::from(text)),
287 Message::Binary(data) => Self::Binary(data),
288 Message::Ping(data) => Self::Ping(data),
289 Message::Pong(data) => Self::Pong(data),
290 Message::Close { code, reason } => {
291 Self::Close(Some(tungstenite::protocol::CloseFrame {
292 code: code.into(),
293 reason: reason.into(),
294 }))
295 }
296 }
297 }
298}
299
300impl From<tungstenite::protocol::frame::coding::CloseCode> for CloseCode {
301 fn from(value: tungstenite::protocol::frame::coding::CloseCode) -> Self {
302 u16::from(value).into()
303 }
304}
305
306impl From<CloseCode> for tungstenite::protocol::frame::coding::CloseCode {
307 fn from(value: CloseCode) -> Self {
308 u16::from(value).into()
309 }
310}