reqwest_websocket/
lib.rs

1#![forbid(unsafe_code)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4//! Provides wrappers for [`reqwest`][2] to enable [`WebSocket`][1] connections.
5//!
6//! # Example
7//!
8//! ```
9//! # use reqwest::Client;
10//! # use reqwest_websocket::{Message, Error};
11//! # use futures_util::{TryStreamExt, SinkExt};
12//! #
13//! # fn main() {
14//! #     // Intentionally ignore the future. We only care that it compiles.
15//! #     let _ = run();
16//! # }
17//! #
18//! # async fn run() -> Result<(), Error> {
19//! // Extends the `reqwest::RequestBuilder` to allow WebSocket upgrades.
20//! use reqwest_websocket::RequestBuilderExt;
21//!
22//! // Creates a GET request, upgrades and sends it.
23//! let response = Client::default()
24//!     .get("wss://echo.websocket.org/")
25//!     .upgrade() // Prepares the WebSocket upgrade.
26//!     .send()
27//!     .await?;
28//!
29//! // Turns the response into a WebSocket stream.
30//! let mut websocket = response.into_websocket().await?;
31//!
32//! // The WebSocket implements `Sink<Message>`.
33//! websocket.send(Message::Text("Hello, World".into())).await?;
34//!
35//! // The WebSocket is also a `TryStream` over `Message`s.
36//! while let Some(message) = websocket.try_next().await? {
37//!     if let Message::Text(text) = message {
38//!         println!("received: {text}")
39//!     }
40//! }
41//! # Ok(())
42//! # }
43//! ```
44//!
45//! [1]: https://en.wikipedia.org/wiki/WebSocket
46//! [2]: https://docs.rs/reqwest/latest/reqwest/index.html
47
48#[cfg(feature = "json")]
49mod json;
50#[cfg(not(target_arch = "wasm32"))]
51mod native;
52mod protocol;
53#[cfg(target_arch = "wasm32")]
54mod wasm;
55
56use std::{
57    pin::Pin,
58    task::{ready, Context, Poll},
59};
60
61#[cfg(not(target_arch = "wasm32"))]
62#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
63pub use crate::native::HandshakeError;
64pub use crate::protocol::{CloseCode, Message};
65pub use bytes::Bytes;
66use futures_util::{Sink, SinkExt, Stream, StreamExt};
67use reqwest::{Client, ClientBuilder, IntoUrl, RequestBuilder};
68
69/// Errors returned by `reqwest_websocket`.
70#[derive(Debug, thiserror::Error)]
71#[non_exhaustive]
72pub enum Error {
73    #[cfg(not(target_arch = "wasm32"))]
74    #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
75    #[error("websocket upgrade failed")]
76    Handshake(#[from] HandshakeError),
77
78    #[error("reqwest error")]
79    Reqwest(#[from] reqwest::Error),
80
81    #[cfg(not(target_arch = "wasm32"))]
82    #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
83    #[error("tungstenite error")]
84    Tungstenite(#[from] tungstenite::Error),
85
86    #[cfg(target_arch = "wasm32")]
87    #[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))]
88    #[error("web_sys error")]
89    WebSys(#[from] wasm::WebSysError),
90
91    /// Error during serialization/deserialization.
92    #[error("serde_json error")]
93    #[cfg(feature = "json")]
94    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
95    Json(#[from] serde_json::Error),
96}
97
98/// Opens a `WebSocket` connection at the specified `URL`.
99///
100/// This is a shorthand for creating a [`Request`], sending it, and turning the
101/// [`Response`] into a [`WebSocket`].
102///
103/// [`Request`]: reqwest::Request
104/// [`Response`]: reqwest::Response
105pub async fn websocket(url: impl IntoUrl) -> Result<WebSocket, Error> {
106    builder_http1_only(Client::builder())
107        .build()?
108        .get(url)
109        .upgrade()
110        .send()
111        .await?
112        .into_websocket()
113        .await
114}
115
116#[inline]
117#[cfg(not(target_arch = "wasm32"))]
118fn builder_http1_only(builder: ClientBuilder) -> ClientBuilder {
119    builder.http1_only()
120}
121
122#[inline]
123#[cfg(target_arch = "wasm32")]
124fn builder_http1_only(builder: ClientBuilder) -> ClientBuilder {
125    builder
126}
127
128/// Trait that extends [`reqwest::RequestBuilder`] with an `upgrade` method.
129pub trait RequestBuilderExt {
130    /// Upgrades the [`RequestBuilder`] to perform a `WebSocket` handshake.
131    ///
132    /// This returns a wrapped type, so you must do this after you set up
133    /// your request, and just before sending the request.
134    fn upgrade(self) -> UpgradedRequestBuilder;
135}
136
137impl RequestBuilderExt for RequestBuilder {
138    fn upgrade(self) -> UpgradedRequestBuilder {
139        UpgradedRequestBuilder::new(self)
140    }
141}
142
143/// Wrapper for a [`reqwest::RequestBuilder`] that performs the
144/// `WebSocket` handshake when sent.
145pub struct UpgradedRequestBuilder {
146    inner: RequestBuilder,
147    protocols: Vec<String>,
148    #[cfg(not(target_arch = "wasm32"))]
149    web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
150}
151
152impl UpgradedRequestBuilder {
153    pub(crate) fn new(inner: RequestBuilder) -> Self {
154        Self {
155            inner,
156            protocols: vec![],
157            #[cfg(not(target_arch = "wasm32"))]
158            web_socket_config: None,
159        }
160    }
161
162    /// Selects which sub-protocols are accepted by the client.
163    pub fn protocols<S: Into<String>>(mut self, protocols: impl IntoIterator<Item = S>) -> Self {
164        self.protocols = protocols.into_iter().map(Into::into).collect();
165
166        self
167    }
168
169    /// Sets the WebSocket configuration.
170    #[cfg(not(target_arch = "wasm32"))]
171    pub fn web_socket_config(mut self, config: tungstenite::protocol::WebSocketConfig) -> Self {
172        self.web_socket_config = Some(config);
173        self
174    }
175
176    /// Sends the request and returns an [`UpgradeResponse`].
177    pub async fn send(self) -> Result<UpgradeResponse, Error> {
178        #[cfg(not(target_arch = "wasm32"))]
179        let inner = native::send_request(self.inner, &self.protocols).await?;
180
181        #[cfg(target_arch = "wasm32")]
182        let inner = wasm::WebSysWebSocketStream::new(self.inner.build()?, &self.protocols).await?;
183
184        Ok(UpgradeResponse {
185            inner,
186            protocols: self.protocols,
187            #[cfg(not(target_arch = "wasm32"))]
188            web_socket_config: self.web_socket_config,
189        })
190    }
191}
192
193/// The server's response to the `WebSocket` upgrade request.
194///
195/// On non-wasm platforms, this implements `Deref<Target = Response>`, so you
196/// can access all the usual information from the [`reqwest::Response`].
197pub struct UpgradeResponse {
198    #[cfg(not(target_arch = "wasm32"))]
199    inner: native::WebSocketResponse,
200
201    #[cfg(target_arch = "wasm32")]
202    inner: wasm::WebSysWebSocketStream,
203
204    #[allow(dead_code)]
205    protocols: Vec<String>,
206
207    #[cfg(not(target_arch = "wasm32"))]
208    #[allow(dead_code)]
209    web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
210}
211
212#[cfg(not(target_arch = "wasm32"))]
213impl std::ops::Deref for UpgradeResponse {
214    type Target = reqwest::Response;
215
216    fn deref(&self) -> &Self::Target {
217        &self.inner.response
218    }
219}
220
221impl UpgradeResponse {
222    /// Turns the response into a `WebSocket`.
223    /// This checks if the `WebSocket` handshake was successful.
224    pub async fn into_websocket(self) -> Result<WebSocket, Error> {
225        #[cfg(not(target_arch = "wasm32"))]
226        let (inner, protocol) = self
227            .inner
228            .into_stream_and_protocol(self.protocols, self.web_socket_config)
229            .await?;
230
231        #[cfg(target_arch = "wasm32")]
232        let (inner, protocol) = {
233            let protocol = self.inner.protocol();
234            (self.inner, Some(protocol))
235        };
236
237        Ok(WebSocket { inner, protocol })
238    }
239
240    /// Consumes the response and returns the inner [`reqwest::Response`].
241    #[must_use]
242    #[cfg(not(target_arch = "wasm32"))]
243    pub fn into_inner(self) -> reqwest::Response {
244        self.inner.response
245    }
246}
247
248/// A `WebSocket` connection. Implements [`futures_util::Stream`] and
249/// [`futures_util::Sink`].
250#[derive(Debug)]
251pub struct WebSocket {
252    #[cfg(not(target_arch = "wasm32"))]
253    inner: native::WebSocketStream,
254
255    #[cfg(target_arch = "wasm32")]
256    inner: wasm::WebSysWebSocketStream,
257
258    protocol: Option<String>,
259}
260
261impl WebSocket {
262    /// Returns the protocol negotiated during the handshake.
263    pub fn protocol(&self) -> Option<&str> {
264        self.protocol.as_deref()
265    }
266
267    /// Closes the connection with a given code and (optional) reason.
268    ///
269    /// # WASM
270    ///
271    /// On wasm `code` must be [`CloseCode::Normal`], [`CloseCode::Iana(_)`],
272    /// or [`CloseCode::Library(_)`]. Furthermore `reason` must be at most 123
273    /// bytes long. Otherwise the call to [`close`][Self::close] will fail.
274    pub async fn close(self, code: CloseCode, reason: Option<&str>) -> Result<(), Error> {
275        #[cfg(not(target_arch = "wasm32"))]
276        {
277            let mut inner = self.inner;
278            inner
279                .close(Some(tungstenite::protocol::CloseFrame {
280                    code: code.into(),
281                    reason: reason.unwrap_or_default().into(),
282                }))
283                .await?;
284        }
285
286        #[cfg(target_arch = "wasm32")]
287        self.inner.close(code.into(), reason.unwrap_or_default())?;
288
289        Ok(())
290    }
291}
292
293impl Stream for WebSocket {
294    type Item = Result<Message, Error>;
295
296    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
297        match ready!(self.inner.poll_next_unpin(cx)) {
298            None => Poll::Ready(None),
299            Some(Err(error)) => Poll::Ready(Some(Err(error.into()))),
300            Some(Ok(message)) => {
301                match message.try_into() {
302                    Ok(message) => Poll::Ready(Some(Ok(message))),
303
304                    #[cfg(target_arch = "wasm32")]
305                    Err(e) => match e {},
306
307                    #[cfg(not(target_arch = "wasm32"))]
308                    Err(e) => {
309                        // this fails only for raw frames (which are not received)
310                        panic!("Received an invalid frame: {e}");
311                    }
312                }
313            }
314        }
315    }
316}
317
318impl Sink<Message> for WebSocket {
319    type Error = Error;
320
321    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
322        self.inner.poll_ready_unpin(cx).map_err(Into::into)
323    }
324
325    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
326        self.inner.start_send_unpin(item.into()).map_err(Into::into)
327    }
328
329    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
330        self.inner.poll_flush_unpin(cx).map_err(Into::into)
331    }
332
333    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
334        self.inner.poll_close_unpin(cx).map_err(Into::into)
335    }
336}
337
#[cfg(test)]
mod tests {
    use futures_util::{SinkExt, TryStreamExt};
    use reqwest::Client;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test;

    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    use super::{websocket, CloseCode, Message, RequestBuilderExt, WebSocket};

    /// Sends a text message and returns once the same text is echoed back.
    ///
    /// Panics if the stream ends or errors before the echo arrives.
    async fn test_websocket(mut websocket: WebSocket) {
        let text = "Hello, World!";
        websocket.send(Message::Text(text.into())).await.unwrap();

        while let Some(message) = websocket.try_next().await.unwrap() {
            if let Message::Text(s) = message {
                if s == text {
                    return;
                }
            }
        }

        panic!("didn't receive text back");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_with_request_builder() {
        let websocket = Client::default()
            .get("https://echo.websocket.org/")
            .upgrade()
            .send()
            .await
            .unwrap()
            .into_websocket()
            .await
            .unwrap();

        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_shorthand() {
        let websocket = websocket("https://echo.websocket.org/").await.unwrap();
        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_with_ws_scheme() {
        let websocket = websocket("wss://echo.websocket.org/").await.unwrap();

        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_close() {
        let websocket = websocket("https://echo.websocket.org/").await.unwrap();
        websocket
            .close(CloseCode::Normal, Some("test"))
            .await
            .expect("close returned an error");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_send_close_frame() {
        let mut websocket = websocket("https://echo.websocket.org/").await.unwrap();
        websocket
            .send(Message::Close {
                code: CloseCode::Normal,
                reason: "Can you please reply with a close frame?".into(),
            })
            .await
            .unwrap();

        // Drain the stream until it ends, recording whether a close frame arrived.
        let mut close_received = false;
        while let Some(message) = websocket.try_next().await.unwrap() {
            if let Message::Close { code, .. } = message {
                assert_eq!(code, CloseCode::Normal);
                close_received = true;
            }
        }

        assert!(close_received, "No close frame was received");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    #[ignore = "https://echo.websocket.org/ ignores subprotocols"]
    async fn test_with_subprotocol() {
        let websocket = Client::default()
            .get("https://echo.websocket.org/")
            .upgrade()
            .protocols(["chat"])
            .send()
            .await
            .unwrap()
            .into_websocket()
            .await
            .unwrap();

        assert_eq!(websocket.protocol(), Some("chat"));

        test_websocket(websocket).await;
    }

    #[test]
    fn close_code_from_u16() {
        let value = 1008u16;
        assert_eq!(CloseCode::from(value), CloseCode::Policy);
    }

    #[test]
    fn close_code_into_u16() {
        let code = CloseCode::Away;
        let value: u16 = code.into();
        assert_eq!(value, 1001u16);
        assert_eq!(u16::from(code), 1001u16);
    }
}