reqwest_websocket/
lib.rs

1#![forbid(unsafe_code)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
// Note: the `tungstenite::Error` variant grew, so clippy now warns about `result_large_err`. It is only 136 bytes, which is acceptable; boxing it would be a breaking change.
4#![allow(clippy::result_large_err)]
5
6//! Provides wrappers for [`reqwest`][2] to enable [`WebSocket`][1] connections.
7//!
8//! # Example
9//!
10//! ```
11//! # use reqwest::Client;
12//! # use reqwest_websocket::{Message, Error};
13//! # use futures_util::{TryStreamExt, SinkExt};
14//! #
15//! # fn main() {
16//! #     // Intentionally ignore the future. We only care that it compiles.
17//! #     let _ = run();
18//! # }
19//! #
20//! # async fn run() -> Result<(), Error> {
21//! // Extends the `reqwest::RequestBuilder` to allow WebSocket upgrades.
22//! use reqwest_websocket::RequestBuilderExt;
23//!
24//! // Creates a GET request, upgrades and sends it.
25//! let response = Client::default()
26//!     .get("wss://echo.websocket.org/")
27//!     .upgrade() // Prepares the WebSocket upgrade.
28//!     .send()
29//!     .await?;
30//!
31//! // Turns the response into a WebSocket stream.
32//! let mut websocket = response.into_websocket().await?;
33//!
34//! // The WebSocket implements `Sink<Message>`.
35//! websocket.send(Message::Text("Hello, World".into())).await?;
36//!
37//! // The WebSocket is also a `TryStream` over `Message`s.
38//! while let Some(message) = websocket.try_next().await? {
39//!     if let Message::Text(text) = message {
40//!         println!("received: {text}")
41//!     }
42//! }
43//! # Ok(())
44//! # }
45//! ```
46//!
47//! [1]: https://en.wikipedia.org/wiki/WebSocket
48//! [2]: https://docs.rs/reqwest/latest/reqwest/index.html
49
50#[cfg(feature = "json")]
51mod json;
52#[cfg(not(target_arch = "wasm32"))]
53mod native;
54mod protocol;
55#[cfg(target_arch = "wasm32")]
56mod wasm;
57
58use std::{
59    pin::Pin,
60    task::{ready, Context, Poll},
61};
62
63#[cfg(not(target_arch = "wasm32"))]
64#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
65pub use crate::native::HandshakeError;
66pub use crate::protocol::{CloseCode, Message};
67pub use bytes::Bytes;
68use futures_util::{Sink, SinkExt, Stream, StreamExt};
69use reqwest::{Client, ClientBuilder, IntoUrl, RequestBuilder};
70
71/// Errors returned by `reqwest_websocket`.
72#[derive(Debug, thiserror::Error)]
73#[non_exhaustive]
74pub enum Error {
75    #[cfg(not(target_arch = "wasm32"))]
76    #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
77    #[error("websocket upgrade failed")]
78    Handshake(#[from] HandshakeError),
79
80    #[error("reqwest error")]
81    Reqwest(#[from] reqwest::Error),
82
83    #[cfg(not(target_arch = "wasm32"))]
84    #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
85    #[error("tungstenite error")]
86    Tungstenite(#[from] tungstenite::Error),
87
88    #[cfg(target_arch = "wasm32")]
89    #[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))]
90    #[error("web_sys error")]
91    WebSys(#[from] wasm::WebSysError),
92
93    /// Error during serialization/deserialization.
94    #[error("serde_json error")]
95    #[cfg(feature = "json")]
96    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
97    Json(#[from] serde_json::Error),
98}
99
100/// Opens a `WebSocket` connection at the specified `URL`.
101///
102/// This is a shorthand for creating a [`Request`], sending it, and turning the
103/// [`Response`] into a [`WebSocket`].
104///
105/// [`Request`]: reqwest::Request
106/// [`Response`]: reqwest::Response
107pub async fn websocket(url: impl IntoUrl) -> Result<WebSocket, Error> {
108    builder_http1_only(Client::builder())
109        .build()?
110        .get(url)
111        .upgrade()
112        .send()
113        .await?
114        .into_websocket()
115        .await
116}
117
118#[inline]
119#[cfg(not(target_arch = "wasm32"))]
120fn builder_http1_only(builder: ClientBuilder) -> ClientBuilder {
121    builder.http1_only()
122}
123
124#[inline]
125#[cfg(target_arch = "wasm32")]
126fn builder_http1_only(builder: ClientBuilder) -> ClientBuilder {
127    builder
128}
129
130/// Trait that extends [`reqwest::RequestBuilder`] with an `upgrade` method.
131pub trait RequestBuilderExt {
132    /// Upgrades the [`RequestBuilder`] to perform a `WebSocket` handshake.
133    ///
134    /// This returns a wrapped type, so you must do this after you set up
135    /// your request, and just before sending the request.
136    fn upgrade(self) -> UpgradedRequestBuilder;
137}
138
139impl RequestBuilderExt for RequestBuilder {
140    fn upgrade(self) -> UpgradedRequestBuilder {
141        UpgradedRequestBuilder::new(self)
142    }
143}
144
145/// Wrapper for a [`reqwest::RequestBuilder`] that performs the
146/// `WebSocket` handshake when sent.
147pub struct UpgradedRequestBuilder {
148    inner: RequestBuilder,
149    protocols: Vec<String>,
150    #[cfg(not(target_arch = "wasm32"))]
151    web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
152}
153
154impl UpgradedRequestBuilder {
155    pub(crate) fn new(inner: RequestBuilder) -> Self {
156        Self {
157            inner,
158            protocols: vec![],
159            #[cfg(not(target_arch = "wasm32"))]
160            web_socket_config: None,
161        }
162    }
163
164    /// Selects which sub-protocols are accepted by the client.
165    pub fn protocols<S: Into<String>>(mut self, protocols: impl IntoIterator<Item = S>) -> Self {
166        self.protocols = protocols.into_iter().map(Into::into).collect();
167
168        self
169    }
170
171    /// Sets the WebSocket configuration.
172    #[cfg(not(target_arch = "wasm32"))]
173    pub fn web_socket_config(mut self, config: tungstenite::protocol::WebSocketConfig) -> Self {
174        self.web_socket_config = Some(config);
175        self
176    }
177
178    /// Sends the request and returns an [`UpgradeResponse`].
179    pub async fn send(self) -> Result<UpgradeResponse, Error> {
180        #[cfg(not(target_arch = "wasm32"))]
181        let inner = native::send_request(self.inner, &self.protocols).await?;
182
183        #[cfg(target_arch = "wasm32")]
184        let inner = wasm::WebSysWebSocketStream::new(self.inner.build()?, &self.protocols).await?;
185
186        Ok(UpgradeResponse {
187            inner,
188            protocols: self.protocols,
189            #[cfg(not(target_arch = "wasm32"))]
190            web_socket_config: self.web_socket_config,
191        })
192    }
193}
194
195/// The server's response to the `WebSocket` upgrade request.
196///
197/// On non-wasm platforms, this implements `Deref<Target = Response>`, so you
198/// can access all the usual information from the [`reqwest::Response`].
199pub struct UpgradeResponse {
200    #[cfg(not(target_arch = "wasm32"))]
201    inner: native::WebSocketResponse,
202
203    #[cfg(target_arch = "wasm32")]
204    inner: wasm::WebSysWebSocketStream,
205
206    #[allow(dead_code)]
207    protocols: Vec<String>,
208
209    #[cfg(not(target_arch = "wasm32"))]
210    #[allow(dead_code)]
211    web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
212}
213
214#[cfg(not(target_arch = "wasm32"))]
215impl std::ops::Deref for UpgradeResponse {
216    type Target = reqwest::Response;
217
218    fn deref(&self) -> &Self::Target {
219        &self.inner.response
220    }
221}
222
223impl UpgradeResponse {
224    /// Turns the response into a `WebSocket`.
225    /// This checks if the `WebSocket` handshake was successful.
226    pub async fn into_websocket(self) -> Result<WebSocket, Error> {
227        #[cfg(not(target_arch = "wasm32"))]
228        let (inner, protocol) = self
229            .inner
230            .into_stream_and_protocol(self.protocols, self.web_socket_config)
231            .await?;
232
233        #[cfg(target_arch = "wasm32")]
234        let (inner, protocol) = {
235            let protocol = self.inner.protocol();
236            (self.inner, Some(protocol))
237        };
238
239        Ok(WebSocket { inner, protocol })
240    }
241
242    /// Consumes the response and returns the inner [`reqwest::Response`].
243    #[must_use]
244    #[cfg(not(target_arch = "wasm32"))]
245    pub fn into_inner(self) -> reqwest::Response {
246        self.inner.response
247    }
248}
249
250/// A `WebSocket` connection. Implements [`futures_util::Stream`] and
251/// [`futures_util::Sink`].
252#[derive(Debug)]
253pub struct WebSocket {
254    #[cfg(not(target_arch = "wasm32"))]
255    inner: native::WebSocketStream,
256
257    #[cfg(target_arch = "wasm32")]
258    inner: wasm::WebSysWebSocketStream,
259
260    protocol: Option<String>,
261}
262
263impl WebSocket {
264    /// Returns the protocol negotiated during the handshake.
265    pub fn protocol(&self) -> Option<&str> {
266        self.protocol.as_deref()
267    }
268
269    /// Closes the connection with a given code and (optional) reason.
270    ///
271    /// # WASM
272    ///
273    /// On wasm `code` must be [`CloseCode::Normal`], [`CloseCode::Iana(_)`],
274    /// or [`CloseCode::Library(_)`]. Furthermore `reason` must be at most 123
275    /// bytes long. Otherwise the call to [`close`][Self::close] will fail.
276    pub async fn close(self, code: CloseCode, reason: Option<&str>) -> Result<(), Error> {
277        #[cfg(not(target_arch = "wasm32"))]
278        {
279            let mut inner = self.inner;
280            inner
281                .close(Some(tungstenite::protocol::CloseFrame {
282                    code: code.into(),
283                    reason: reason.unwrap_or_default().into(),
284                }))
285                .await?;
286        }
287
288        #[cfg(target_arch = "wasm32")]
289        self.inner.close(code.into(), reason.unwrap_or_default())?;
290
291        Ok(())
292    }
293}
294
295impl Stream for WebSocket {
296    type Item = Result<Message, Error>;
297
298    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
299        match ready!(self.inner.poll_next_unpin(cx)) {
300            None => Poll::Ready(None),
301            Some(Err(error)) => Poll::Ready(Some(Err(error.into()))),
302            Some(Ok(message)) => {
303                match message.try_into() {
304                    Ok(message) => Poll::Ready(Some(Ok(message))),
305
306                    #[cfg(target_arch = "wasm32")]
307                    Err(e) => match e {},
308
309                    #[cfg(not(target_arch = "wasm32"))]
310                    Err(e) => {
311                        // this fails only for raw frames (which are not received)
312                        panic!("Received an invalid frame: {e}");
313                    }
314                }
315            }
316        }
317    }
318}
319
320impl Sink<Message> for WebSocket {
321    type Error = Error;
322
323    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
324        self.inner.poll_ready_unpin(cx).map_err(Into::into)
325    }
326
327    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
328        self.inner.start_send_unpin(item.into()).map_err(Into::into)
329    }
330
331    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
332        self.inner.poll_flush_unpin(cx).map_err(Into::into)
333    }
334
335    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
336        self.inner.poll_close_unpin(cx).map_err(Into::into)
337    }
338}
339
#[cfg(test)]
mod tests {
    use futures_util::{SinkExt, TryStreamExt};
    use reqwest::Client;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test;

    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    use super::{websocket, CloseCode, Message, RequestBuilderExt, WebSocket};

    /// Sends a text message and waits until the echo server sends it back.
    ///
    /// Panics if the stream ends before the echo arrives.
    async fn test_websocket(mut websocket: WebSocket) {
        let text = "Hello, World!";
        websocket.send(Message::Text(text.into())).await.unwrap();

        while let Some(message) = websocket.try_next().await.unwrap() {
            // Skip anything that is not the echoed text; the server may send
            // other messages first.
            if matches!(message, Message::Text(s) if s == text) {
                return;
            }
        }

        panic!("didn't receive text back");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_with_request_builder() {
        let websocket = Client::default()
            .get("https://echo.websocket.org/")
            .upgrade()
            .send()
            .await
            .unwrap()
            .into_websocket()
            .await
            .unwrap();

        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_shorthand() {
        let websocket = websocket("https://echo.websocket.org/").await.unwrap();
        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_with_ws_scheme() {
        let websocket = websocket("wss://echo.websocket.org/").await.unwrap();

        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_close() {
        let websocket = websocket("https://echo.websocket.org/").await.unwrap();
        websocket
            .close(CloseCode::Normal, Some("test"))
            .await
            .expect("close returned an error");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_send_close_frame() {
        let mut websocket = websocket("https://echo.websocket.org/").await.unwrap();
        websocket
            .send(Message::Close {
                code: CloseCode::Normal,
                reason: "Can you please reply with a close frame?".into(),
            })
            .await
            .unwrap();

        // Drain the stream until it ends, checking every close frame seen.
        let mut close_received = false;
        while let Some(message) = websocket.try_next().await.unwrap() {
            if let Message::Close { code, .. } = message {
                assert_eq!(code, CloseCode::Normal);
                close_received = true;
            }
        }

        assert!(close_received, "No close frame was received");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    #[ignore = "https://echo.websocket.org/ ignores subprotocols"]
    async fn test_with_subprotocol() {
        let websocket = Client::default()
            .get("https://echo.websocket.org/")
            .upgrade()
            .protocols(["chat"])
            .send()
            .await
            .unwrap()
            .into_websocket()
            .await
            .unwrap();

        assert_eq!(websocket.protocol(), Some("chat"));

        test_websocket(websocket).await;
    }

    #[test]
    fn close_code_from_u16() {
        let raw = 1008u16;
        assert_eq!(CloseCode::from(raw), CloseCode::Policy);
    }

    #[test]
    fn close_code_into_u16() {
        let code = CloseCode::Away;
        let raw: u16 = code.into();
        assert_eq!(raw, 1001u16);
        assert_eq!(u16::from(code), 1001u16);
    }
}