reqwest_websocket/
lib.rs

1#![forbid(unsafe_code)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
//! Provides wrappers for [`reqwest`][2] to enable [`WebSocket`][1] connections.
//!
//! # Example
//!
//! ```
//! # use reqwest::Client;
//! # use reqwest_websocket::{Message, Error};
//! # use futures_util::{TryStreamExt, SinkExt};
//! #
//! # fn main() {
//! #     // Intentionally ignore the future. We only care that it compiles.
//! #     let _ = run();
//! # }
//! #
//! # async fn run() -> Result<(), Error> {
//! // Extends the `reqwest::RequestBuilder` to allow WebSocket upgrades.
//! use reqwest_websocket::RequestBuilderExt;
//!
//! // Creates a GET request, upgrades and sends it.
//! let response = Client::default()
//!     .get("wss://echo.websocket.org/")
//!     .upgrade() // Prepares the WebSocket upgrade.
//!     .send()
//!     .await?;
//!
//! // Turns the response into a WebSocket stream.
//! let mut websocket = response.into_websocket().await?;
//!
//! // The WebSocket implements `Sink<Message>`.
//! websocket.send(Message::Text("Hello, World".into())).await?;
//!
//! // The WebSocket is also a `TryStream` over `Message`s.
//! while let Some(message) = websocket.try_next().await? {
//!     if let Message::Text(text) = message {
//!         println!("received: {text}")
//!     }
//! }
//! # Ok(())
//! # }
//! ```
//!
//! [1]: https://en.wikipedia.org/wiki/WebSocket
//! [2]: https://docs.rs/reqwest/latest/reqwest/index.html
47
48#[cfg(feature = "json")]
49mod json;
50#[cfg(not(target_arch = "wasm32"))]
51mod native;
52mod protocol;
53#[cfg(target_arch = "wasm32")]
54mod wasm;
55
56use std::{
57    pin::Pin,
58    task::{ready, Context, Poll},
59};
60
61#[cfg(not(target_arch = "wasm32"))]
62#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
63pub use crate::native::HandshakeError;
64pub use crate::protocol::{CloseCode, Message};
65use futures_util::{Sink, SinkExt, Stream, StreamExt};
66use reqwest::{Client, ClientBuilder, IntoUrl, RequestBuilder};
67
68/// Errors returned by `reqwest_websocket`.
69#[derive(Debug, thiserror::Error)]
70#[non_exhaustive]
71pub enum Error {
72    #[cfg(not(target_arch = "wasm32"))]
73    #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
74    #[error("websocket upgrade failed")]
75    Handshake(#[from] HandshakeError),
76
77    #[error("reqwest error")]
78    Reqwest(#[from] reqwest::Error),
79
80    #[cfg(not(target_arch = "wasm32"))]
81    #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
82    #[error("tungstenite error")]
83    Tungstenite(#[from] tungstenite::Error),
84
85    #[cfg(target_arch = "wasm32")]
86    #[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))]
87    #[error("web_sys error")]
88    WebSys(#[from] wasm::WebSysError),
89
90    /// Error during serialization/deserialization.
91    #[error("serde_json error")]
92    #[cfg(feature = "json")]
93    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
94    Json(#[from] serde_json::Error),
95}
96
97/// Opens a `WebSocket` connection at the specified `URL`.
98///
99/// This is a shorthand for creating a [`Request`], sending it, and turning the
100/// [`Response`] into a [`WebSocket`].
101///
102/// [`Request`]: reqwest::Request
103/// [`Response`]: reqwest::Response
104pub async fn websocket(url: impl IntoUrl) -> Result<WebSocket, Error> {
105    builder_http1_only(Client::builder())
106        .build()?
107        .get(url)
108        .upgrade()
109        .send()
110        .await?
111        .into_websocket()
112        .await
113}
114
115#[inline]
116#[cfg(not(target_arch = "wasm32"))]
117fn builder_http1_only(builder: ClientBuilder) -> ClientBuilder {
118    builder.http1_only()
119}
120
121#[inline]
122#[cfg(target_arch = "wasm32")]
123fn builder_http1_only(builder: ClientBuilder) -> ClientBuilder {
124    builder
125}
126
127/// Trait that extends [`reqwest::RequestBuilder`] with an `upgrade` method.
128pub trait RequestBuilderExt {
129    /// Upgrades the [`RequestBuilder`] to perform a `WebSocket` handshake.
130    ///
131    /// This returns a wrapped type, so you must do this after you set up
132    /// your request, and just before sending the request.
133    fn upgrade(self) -> UpgradedRequestBuilder;
134}
135
136impl RequestBuilderExt for RequestBuilder {
137    fn upgrade(self) -> UpgradedRequestBuilder {
138        UpgradedRequestBuilder::new(self)
139    }
140}
141
142/// Wrapper for a [`reqwest::RequestBuilder`] that performs the
143/// `WebSocket` handshake when sent.
144pub struct UpgradedRequestBuilder {
145    inner: RequestBuilder,
146    protocols: Vec<String>,
147    #[cfg(not(target_arch = "wasm32"))]
148    web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
149}
150
151impl UpgradedRequestBuilder {
152    pub(crate) fn new(inner: RequestBuilder) -> Self {
153        Self {
154            inner,
155            protocols: vec![],
156            #[cfg(not(target_arch = "wasm32"))]
157            web_socket_config: None,
158        }
159    }
160
161    /// Selects which sub-protocols are accepted by the client.
162    pub fn protocols<S: Into<String>>(mut self, protocols: impl IntoIterator<Item = S>) -> Self {
163        self.protocols = protocols.into_iter().map(Into::into).collect();
164
165        self
166    }
167
168    /// Sets the WebSocket configuration.
169    #[cfg(not(target_arch = "wasm32"))]
170    pub fn web_socket_config(mut self, config: tungstenite::protocol::WebSocketConfig) -> Self {
171        self.web_socket_config = Some(config);
172        self
173    }
174
175    /// Sends the request and returns an [`UpgradeResponse`].
176    pub async fn send(self) -> Result<UpgradeResponse, Error> {
177        #[cfg(not(target_arch = "wasm32"))]
178        let inner = native::send_request(self.inner, &self.protocols).await?;
179
180        #[cfg(target_arch = "wasm32")]
181        let inner = wasm::WebSysWebSocketStream::new(self.inner.build()?, &self.protocols).await?;
182
183        Ok(UpgradeResponse {
184            inner,
185            protocols: self.protocols,
186            #[cfg(not(target_arch = "wasm32"))]
187            web_socket_config: self.web_socket_config,
188        })
189    }
190}
191
192/// The server's response to the `WebSocket` upgrade request.
193///
194/// On non-wasm platforms, this implements `Deref<Target = Response>`, so you
195/// can access all the usual information from the [`reqwest::Response`].
196pub struct UpgradeResponse {
197    #[cfg(not(target_arch = "wasm32"))]
198    inner: native::WebSocketResponse,
199
200    #[cfg(target_arch = "wasm32")]
201    inner: wasm::WebSysWebSocketStream,
202
203    #[allow(dead_code)]
204    protocols: Vec<String>,
205
206    #[cfg(not(target_arch = "wasm32"))]
207    #[allow(dead_code)]
208    web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
209}
210
211#[cfg(not(target_arch = "wasm32"))]
212impl std::ops::Deref for UpgradeResponse {
213    type Target = reqwest::Response;
214
215    fn deref(&self) -> &Self::Target {
216        &self.inner.response
217    }
218}
219
220impl UpgradeResponse {
221    /// Turns the response into a `WebSocket`.
222    /// This checks if the `WebSocket` handshake was successful.
223    pub async fn into_websocket(self) -> Result<WebSocket, Error> {
224        #[cfg(not(target_arch = "wasm32"))]
225        let (inner, protocol) = self
226            .inner
227            .into_stream_and_protocol(self.protocols, self.web_socket_config)
228            .await?;
229
230        #[cfg(target_arch = "wasm32")]
231        let (inner, protocol) = {
232            let protocol = self.inner.protocol();
233            (self.inner, Some(protocol))
234        };
235
236        Ok(WebSocket { inner, protocol })
237    }
238
239    /// Consumes the response and returns the inner [`reqwest::Response`].
240    #[must_use]
241    #[cfg(not(target_arch = "wasm32"))]
242    pub fn into_inner(self) -> reqwest::Response {
243        self.inner.response
244    }
245}
246
247/// A `WebSocket` connection. Implements [`futures_util::Stream`] and
248/// [`futures_util::Sink`].
249#[derive(Debug)]
250pub struct WebSocket {
251    #[cfg(not(target_arch = "wasm32"))]
252    inner: native::WebSocketStream,
253
254    #[cfg(target_arch = "wasm32")]
255    inner: wasm::WebSysWebSocketStream,
256
257    protocol: Option<String>,
258}
259
260impl WebSocket {
261    /// Returns the protocol negotiated during the handshake.
262    pub fn protocol(&self) -> Option<&str> {
263        self.protocol.as_deref()
264    }
265
266    /// Closes the connection with a given code and (optional) reason.
267    ///
268    /// # WASM
269    ///
270    /// On wasm `code` must be [`CloseCode::Normal`], [`CloseCode::Iana(_)`],
271    /// or [`CloseCode::Library(_)`]. Furthermore `reason` must be at most 123
272    /// bytes long. Otherwise the call to [`close`][Self::close] will fail.
273    pub async fn close(self, code: CloseCode, reason: Option<&str>) -> Result<(), Error> {
274        #[cfg(not(target_arch = "wasm32"))]
275        {
276            let mut inner = self.inner;
277            inner
278                .close(Some(tungstenite::protocol::CloseFrame {
279                    code: code.into(),
280                    reason: reason.unwrap_or_default().into(),
281                }))
282                .await?;
283        }
284
285        #[cfg(target_arch = "wasm32")]
286        self.inner.close(code.into(), reason.unwrap_or_default())?;
287
288        Ok(())
289    }
290}
291
292impl Stream for WebSocket {
293    type Item = Result<Message, Error>;
294
295    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
296        match ready!(self.inner.poll_next_unpin(cx)) {
297            None => Poll::Ready(None),
298            Some(Err(error)) => Poll::Ready(Some(Err(error.into()))),
299            Some(Ok(message)) => {
300                match message.try_into() {
301                    Ok(message) => Poll::Ready(Some(Ok(message))),
302
303                    #[cfg(target_arch = "wasm32")]
304                    Err(e) => match e {},
305
306                    #[cfg(not(target_arch = "wasm32"))]
307                    Err(e) => {
308                        // this fails only for raw frames (which are not received)
309                        panic!("Received an invalid frame: {e}");
310                    }
311                }
312            }
313        }
314    }
315}
316
317impl Sink<Message> for WebSocket {
318    type Error = Error;
319
320    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
321        self.inner.poll_ready_unpin(cx).map_err(Into::into)
322    }
323
324    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
325        self.inner.start_send_unpin(item.into()).map_err(Into::into)
326    }
327
328    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
329        self.inner.poll_flush_unpin(cx).map_err(Into::into)
330    }
331
332    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
333        self.inner.poll_close_unpin(cx).map_err(Into::into)
334    }
335}
336
#[cfg(test)]
mod tests {
    use futures_util::{SinkExt, TryStreamExt};
    use reqwest::Client;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test;

    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    use super::{websocket, CloseCode, Message, RequestBuilderExt, WebSocket};

    /// Sends a text message and waits until the echo server sends it back.
    ///
    /// Any other messages received in between are ignored.
    async fn test_websocket(mut websocket: WebSocket) {
        let text = "Hello, World!";
        websocket
            .send(Message::Text(text.to_owned()))
            .await
            .unwrap();

        while let Some(message) = websocket.try_next().await.unwrap() {
            if matches!(message, Message::Text(s) if s == text) {
                return;
            }
        }

        panic!("didn't receive text back");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_with_request_builder() {
        let websocket = Client::default()
            .get("https://echo.websocket.org/")
            .upgrade()
            .send()
            .await
            .unwrap()
            .into_websocket()
            .await
            .unwrap();

        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_shorthand() {
        let websocket = websocket("https://echo.websocket.org/").await.unwrap();
        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_with_ws_scheme() {
        let websocket = websocket("wss://echo.websocket.org/").await.unwrap();

        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_close() {
        let websocket = websocket("https://echo.websocket.org/").await.unwrap();
        websocket
            .close(CloseCode::Normal, Some("test"))
            .await
            .expect("close returned an error");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_send_close_frame() {
        let mut websocket = websocket("https://echo.websocket.org/").await.unwrap();
        websocket
            .send(Message::Close {
                code: CloseCode::Normal,
                reason: "Can you please reply with a close frame?".into(),
            })
            .await
            .unwrap();

        // Drain the stream until the server ends it, remembering whether a
        // close frame was seen on the way.
        let mut close_received = false;
        while let Some(message) = websocket.try_next().await.unwrap() {
            if let Message::Close { code, .. } = message {
                assert_eq!(code, CloseCode::Normal);
                close_received = true;
            }
        }

        assert!(close_received, "No close frame was received");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    #[ignore = "https://echo.websocket.org/ ignores subprotocols"]
    async fn test_with_subprotocol() {
        let websocket = Client::default()
            .get("https://echo.websocket.org/")
            .upgrade()
            .protocols(["chat"])
            .send()
            .await
            .unwrap()
            .into_websocket()
            .await
            .unwrap();

        assert_eq!(websocket.protocol(), Some("chat"));

        test_websocket(websocket).await;
    }

    #[test]
    fn close_code_from_u16() {
        let raw = 1008u16;
        assert_eq!(CloseCode::from(raw), CloseCode::Policy);
    }

    #[test]
    fn close_code_into_u16() {
        let code = CloseCode::Away;
        let raw: u16 = code.into();
        assert_eq!(raw, 1001u16);
        assert_eq!(u16::from(code), 1001u16);
    }
}