Skip to main content

reqwest_websocket/
lib.rs

1#![forbid(unsafe_code)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3// note: the tungstenite error variant got bigger, so now clippy complains. but it's only 136 bytes, so I think it's fine. Boxing this would require a breaking change.
4#![allow(clippy::result_large_err)]
5
6//! Provides wrappers for [`reqwest`][2] to enable [`WebSocket`][1] connections.
7//!
8//! # Example
9//!
10//! ```
11//! # use reqwest::Client;
12//! # use reqwest_websocket::{Message, Error};
13//! # use futures_util::{TryStreamExt, SinkExt};
14//! #
15//! # fn main() {
16//! #     // Intentionally ignore the future. We only care that it compiles.
17//! #     let _ = run();
18//! # }
19//! #
20//! # async fn run() -> Result<(), Error> {
21//! // Extends the `reqwest::RequestBuilder` to allow WebSocket upgrades.
22//! use reqwest_websocket::Upgrade;
23//!
24//! // Creates a GET request, upgrades and sends it.
25//! let response = Client::default()
26//!     .get("wss://echo.websocket.org/")
27//!     .upgrade() // Prepares the WebSocket upgrade.
28//!     .send()
29//!     .await?;
30//!
31//! // Turns the response into a WebSocket stream.
32//! let mut websocket = response.into_websocket().await?;
33//!
34//! // The WebSocket implements `Sink<Message>`.
35//! websocket.send(Message::Text("Hello, World".into())).await?;
36//!
37//! // The WebSocket is also a `TryStream` over `Message`s.
38//! while let Some(message) = websocket.try_next().await? {
39//!     if let Message::Text(text) = message {
40//!         println!("received: {text}")
41//!     }
42//! }
43//! # Ok(())
44//! # }
45//! ```
46//!
47//! [1]: https://en.wikipedia.org/wiki/WebSocket
48//! [2]: https://docs.rs/reqwest/latest/reqwest/index.html
49
50#[cfg(feature = "json")]
51mod json;
52#[cfg(feature = "middleware")]
53mod middleware;
54#[cfg(not(target_arch = "wasm32"))]
55mod native;
56mod protocol;
57#[cfg(target_arch = "wasm32")]
58mod wasm;
59
60use std::{
61    future::Future,
62    pin::Pin,
63    task::{ready, Context, Poll},
64};
65
66#[cfg(not(target_arch = "wasm32"))]
67#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
68pub use crate::native::HandshakeError;
69pub use crate::protocol::{CloseCode, Message};
70pub use bytes::Bytes;
71use futures_util::{Sink, SinkExt, Stream, StreamExt};
72use reqwest::IntoUrl;
73
74/// Errors returned by `reqwest_websocket`.
75#[derive(Debug, thiserror::Error)]
76#[non_exhaustive]
77pub enum Error {
78    #[cfg(not(target_arch = "wasm32"))]
79    #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
80    #[error("websocket upgrade failed")]
81    Handshake(#[from] HandshakeError),
82
83    #[error("reqwest error")]
84    Reqwest(#[from] reqwest::Error),
85
86    #[cfg(not(target_arch = "wasm32"))]
87    #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
88    #[error("tungstenite error")]
89    Tungstenite(#[from] tungstenite::Error),
90
91    #[cfg(target_arch = "wasm32")]
92    #[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))]
93    #[error("web_sys error")]
94    WebSys(#[from] wasm::WebSysError),
95
96    /// Error during serialization/deserialization.
97    #[cfg(feature = "json")]
98    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
99    #[error("serde_json error")]
100    Json(#[from] serde_json::Error),
101
102    #[cfg(feature = "middleware")]
103    #[error("reqwest_middleware error")]
104    ReqwestMiddleware(#[from] reqwest_middleware::Error),
105}
106
107/// Opens a `WebSocket` connection at the specified `URL`.
108///
109/// This is a shorthand for creating a [`Request`], sending it, and turning the
110/// [`Response`] into a [`WebSocket`].
111///
112/// [`Request`]: reqwest::Request
113/// [`Response`]: reqwest::Response
114pub async fn websocket(url: impl IntoUrl) -> Result<WebSocket, Error> {
115    builder_http1_only(reqwest::Client::builder())
116        .build()?
117        .get(url)
118        .upgrade()
119        .send()
120        .await?
121        .into_websocket()
122        .await
123}
124
125#[inline]
126#[cfg(not(target_arch = "wasm32"))]
127fn builder_http1_only(builder: reqwest::ClientBuilder) -> reqwest::ClientBuilder {
128    builder.http1_only()
129}
130
131#[inline]
132#[cfg(target_arch = "wasm32")]
133fn builder_http1_only(builder: reqwest::ClientBuilder) -> reqwest::ClientBuilder {
134    builder
135}
136
137/// A generic client.
138///
139/// This is needed by [`RequestBuilder`] to be generic over the specific implementation of a client.
140/// Its only requirement is to be able to execute [`reqwest::Request`]s.
141///
142/// This is implemented for [`reqwest::Client`] and [`reqwest_middleware::ClientWithMiddleware`] (with `middleware` feature).
143/// It provides a single interface for executing a [`reqwest::Request`].
144pub trait Client {
145    fn execute(
146        &self,
147        request: reqwest::Request,
148    ) -> impl Future<Output = Result<reqwest::Response, Error>> + '_;
149}
150
151impl Client for reqwest::Client {
152    async fn execute(&self, request: reqwest::Request) -> Result<reqwest::Response, Error> {
153        self.execute(request).await.map_err(Into::into)
154    }
155}
156
157/// A generic request builder.
158///
159/// This is needed by [`Upgraded`] to be generic over the specific implementation of a request (and client).
160/// Its only requirements are that it provides the specific client type, and can build itself into a client and a [`reqwest::Request`].
161pub trait RequestBuilder {
162    type Client: Client;
163
164    fn build_split(self) -> (Self::Client, Result<reqwest::Request, Error>);
165}
166
167impl RequestBuilder for reqwest::RequestBuilder {
168    type Client = reqwest::Client;
169
170    fn build_split(self) -> (Self::Client, Result<reqwest::Request, Error>) {
171        let (client, request) = reqwest::RequestBuilder::build_split(self);
172        (client, request.map_err(Into::into))
173    }
174}
175
176/// Extension trait for requests builders that can be upgraded to a websocket connection.
177///
178/// This is automatically implemented for anything that implements our [`RequestBuilder`] trait.
179pub trait Upgrade: Sized {
180    /// Upgrades the [`RequestBuilder`] to perform a `WebSocket` handshake.
181    ///
182    /// This returns a wrapped type, so you must do this after you set up
183    /// your request, and just before sending the request.
184    fn upgrade(self) -> Upgraded<Self>;
185}
186
187impl<R> Upgrade for R
188where
189    R: RequestBuilder,
190{
191    fn upgrade(self) -> Upgraded<Self> {
192        Upgraded::new(self)
193    }
194}
195
196/// Wrapper for a [`reqwest::RequestBuilder`] that performs the
197/// `WebSocket` handshake when sent.
198pub struct Upgraded<R> {
199    inner: R,
200    protocols: Vec<String>,
201    #[cfg(not(target_arch = "wasm32"))]
202    web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
203}
204
205impl<R> Upgraded<R>
206where
207    R: RequestBuilder,
208{
209    pub(crate) fn new(inner: R) -> Self {
210        Self {
211            inner,
212            protocols: vec![],
213            #[cfg(not(target_arch = "wasm32"))]
214            web_socket_config: None,
215        }
216    }
217
218    /// Selects which sub-protocols are accepted by the client.
219    pub fn protocols<S: Into<String>>(mut self, protocols: impl IntoIterator<Item = S>) -> Self {
220        self.protocols = protocols.into_iter().map(Into::into).collect();
221
222        self
223    }
224
225    /// Sets the WebSocket configuration.
226    #[cfg(not(target_arch = "wasm32"))]
227    pub fn web_socket_config(mut self, config: tungstenite::protocol::WebSocketConfig) -> Self {
228        self.web_socket_config = Some(config);
229        self
230    }
231
232    /// Sends the request and returns an [`UpgradeResponse`].
233    pub async fn send(self) -> Result<UpgradeResponse, Error> {
234        #[cfg(not(target_arch = "wasm32"))]
235        let inner = native::send_request(self.inner, &self.protocols).await?;
236
237        #[cfg(target_arch = "wasm32")]
238        let inner = {
239            let request = self.inner.build_split().1?;
240            wasm::WebSysWebSocketStream::new(request, &self.protocols).await?
241        };
242
243        Ok(UpgradeResponse {
244            inner,
245            protocols: self.protocols,
246            #[cfg(not(target_arch = "wasm32"))]
247            web_socket_config: self.web_socket_config,
248        })
249    }
250}
251
252/// The server's response to the `WebSocket` upgrade request.
253///
254/// On non-wasm platforms, this implements `Deref<Target = Response>`, so you
255/// can access all the usual information from the [`reqwest::Response`].
256pub struct UpgradeResponse {
257    #[cfg(not(target_arch = "wasm32"))]
258    inner: native::WebSocketResponse,
259
260    #[cfg(target_arch = "wasm32")]
261    inner: wasm::WebSysWebSocketStream,
262
263    #[allow(dead_code)]
264    protocols: Vec<String>,
265
266    #[cfg(not(target_arch = "wasm32"))]
267    #[allow(dead_code)]
268    web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
269}
270
271#[cfg(not(target_arch = "wasm32"))]
272impl std::ops::Deref for UpgradeResponse {
273    type Target = reqwest::Response;
274
275    fn deref(&self) -> &Self::Target {
276        &self.inner.response
277    }
278}
279
280impl UpgradeResponse {
281    /// Turns the response into a `WebSocket`.
282    /// This checks if the `WebSocket` handshake was successful.
283    pub async fn into_websocket(self) -> Result<WebSocket, Error> {
284        #[cfg(not(target_arch = "wasm32"))]
285        let (inner, protocol) = self
286            .inner
287            .into_stream_and_protocol(self.protocols, self.web_socket_config)
288            .await?;
289
290        #[cfg(target_arch = "wasm32")]
291        let (inner, protocol) = {
292            let protocol = self.inner.protocol();
293            (self.inner, Some(protocol))
294        };
295
296        Ok(WebSocket { inner, protocol })
297    }
298
299    /// Consumes the response and returns the inner [`reqwest::Response`].
300    #[must_use]
301    #[cfg(not(target_arch = "wasm32"))]
302    pub fn into_inner(self) -> reqwest::Response {
303        self.inner.response
304    }
305}
306
307/// A `WebSocket` connection. Implements [`futures_util::Stream`] and
308/// [`futures_util::Sink`].
309#[derive(Debug)]
310pub struct WebSocket {
311    #[cfg(not(target_arch = "wasm32"))]
312    inner: native::WebSocketStream,
313
314    #[cfg(target_arch = "wasm32")]
315    inner: wasm::WebSysWebSocketStream,
316
317    protocol: Option<String>,
318}
319
320impl WebSocket {
321    /// Returns the protocol negotiated during the handshake.
322    pub fn protocol(&self) -> Option<&str> {
323        self.protocol.as_deref()
324    }
325
326    /// Closes the connection with a given code and (optional) reason.
327    ///
328    /// # WASM
329    ///
330    /// On wasm `code` must be [`CloseCode::Normal`], [`CloseCode::Iana(_)`],
331    /// or [`CloseCode::Library(_)`]. Furthermore `reason` must be at most 123
332    /// bytes long. Otherwise the call to [`close`][Self::close] will fail.
333    pub async fn close(self, code: CloseCode, reason: Option<&str>) -> Result<(), Error> {
334        #[cfg(not(target_arch = "wasm32"))]
335        {
336            let mut inner = self.inner;
337            inner
338                .close(Some(tungstenite::protocol::CloseFrame {
339                    code: code.into(),
340                    reason: reason.unwrap_or_default().into(),
341                }))
342                .await?;
343        }
344
345        #[cfg(target_arch = "wasm32")]
346        self.inner.close(code.into(), reason.unwrap_or_default())?;
347
348        Ok(())
349    }
350}
351
352impl Stream for WebSocket {
353    type Item = Result<Message, Error>;
354
355    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
356        match ready!(self.inner.poll_next_unpin(cx)) {
357            None => Poll::Ready(None),
358            Some(Err(error)) => Poll::Ready(Some(Err(error.into()))),
359            Some(Ok(message)) => {
360                match message.try_into() {
361                    Ok(message) => Poll::Ready(Some(Ok(message))),
362
363                    #[cfg(target_arch = "wasm32")]
364                    Err(e) => match e {},
365
366                    #[cfg(not(target_arch = "wasm32"))]
367                    Err(e) => {
368                        // this fails only for raw frames (which are not received)
369                        panic!("Received an invalid frame: {e}");
370                    }
371                }
372            }
373        }
374    }
375}
376
377impl Sink<Message> for WebSocket {
378    type Error = Error;
379
380    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
381        self.inner.poll_ready_unpin(cx).map_err(Into::into)
382    }
383
384    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
385        self.inner.start_send_unpin(item.into()).map_err(Into::into)
386    }
387
388    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
389        self.inner.poll_flush_unpin(cx).map_err(Into::into)
390    }
391
392    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
393        self.inner.poll_close_unpin(cx).map_err(Into::into)
394    }
395}
396
#[cfg(test)]
mod tests {
    use futures_util::{SinkExt, TryStreamExt};
    use reqwest::Client;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test;

    use crate::{websocket, CloseCode, Message, Upgrade, WebSocket};

    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    /// Local echo server used by the native tests.
    ///
    /// Text and binary messages are echoed back. If a sub-protocol was
    /// negotiated, the server first sends `protocol: <name>`. Dropping the
    /// value signals the server to shut down gracefully.
    #[cfg(not(target_arch = "wasm32"))]
    #[derive(Debug)]
    pub struct TestServer {
        shutdown_sender: Option<tokio::sync::oneshot::Sender<()>>,
        http_url: String,
        ws_url: String,
    }

    #[cfg(not(target_arch = "wasm32"))]
    impl TestServer {
        /// Binds an ephemeral port on `localhost` and serves the echo handler
        /// from a spawned tokio task.
        pub async fn new() -> Self {
            // Per-connection handler: optional protocol greeting, then the echo loop.
            async fn handle_connection(mut socket: axum::extract::ws::WebSocket) {
                if let Some(protocol) = socket.protocol() {
                    if let Ok(protocol) = protocol.to_str() {
                        println!("server/protocol: {protocol:?}");
                        if let Err(error) = socket
                            .send(axum::extract::ws::Message::Text(
                                format!("protocol: {protocol}").into(),
                            ))
                            .await
                        {
                            eprintln!("server/send: {error}");
                            return;
                        }
                    } else {
                        println!("server/protocol: could not convert to utf-8");
                    }
                }

                while let Some(message) = socket.recv().await {
                    match message {
                        Ok(message) => {
                            // Only data frames are echoed; anything else
                            // (e.g. ping/pong/close) is not sent back.
                            if matches!(
                                message,
                                axum::extract::ws::Message::Text(_)
                                    | axum::extract::ws::Message::Binary(_)
                            ) {
                                if let Err(error) = socket.send(message).await {
                                    eprintln!("server/send: {error}");
                                    break;
                                }
                            }
                        }
                        Err(error) => {
                            eprintln!("server/recv: {error}");
                            break;
                        }
                    }
                }
            }

            let (shutdown_sender, shutdown_receiver) = tokio::sync::oneshot::channel();
            let listener = tokio::net::TcpListener::bind(("localhost", 0))
                .await
                .unwrap();
            let port = listener.local_addr().unwrap().port();
            let app = axum::Router::new().route(
                "/",
                axum::routing::any(|ws: axum::extract::ws::WebSocketUpgrade| async move {
                    ws.protocols(["chat"]).on_upgrade(handle_connection)
                }),
            );

            // todo: I think we'll need to spawn this on a proper thread (for which we create a separate runtime) for this to be shared across multiple tests
            let _join_handle = tokio::spawn(async move {
                axum::serve(listener, app)
                    .with_graceful_shutdown(async move {
                        let _ = shutdown_receiver.await;
                    })
                    .await
                    .unwrap();
            });
            Self {
                shutdown_sender: Some(shutdown_sender),
                http_url: format!("http://localhost:{port}/"),
                ws_url: format!("ws://localhost:{port}/"),
            }
        }

        /// The server address with an `http://` scheme.
        pub fn http_url(&self) -> &str {
            &self.http_url
        }

        /// The server address with a `ws://` scheme.
        pub fn ws_url(&self) -> &str {
            &self.ws_url
        }
    }

    #[cfg(not(target_arch = "wasm32"))]
    impl Drop for TestServer {
        fn drop(&mut self) {
            if let Some(shutdown_sender) = self.shutdown_sender.take() {
                println!("Shutting down server");
                // The server may already be gone; a failed send is fine here.
                let _ = shutdown_sender.send(());
            }
        }
    }

    /// On wasm the tests talk to the public echo service instead of a local server.
    #[cfg(target_arch = "wasm32")]
    pub struct TestServer;

    #[cfg(target_arch = "wasm32")]
    impl TestServer {
        pub async fn new() -> Self {
            Self
        }

        pub fn http_url(&self) -> &str {
            "https://echo.websocket.org/"
        }

        pub fn ws_url(&self) -> &str {
            "wss://echo.websocket.org/"
        }
    }

    /// Sends a text message and waits until the same text is echoed back.
    ///
    /// Panics if the stream ends before the echo arrives.
    pub async fn test_websocket(mut websocket: WebSocket) {
        let text = "Hello, World!";
        websocket.send(Message::Text(text.into())).await.unwrap();

        while let Some(message) = websocket.try_next().await.unwrap() {
            // Skip anything that is not our echo (e.g. a server greeting).
            if let Message::Text(s) = message {
                if s == text {
                    return;
                }
            }
        }

        panic!("didn't receive text back");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_with_request_builder() {
        let echo = TestServer::new().await;

        let websocket = Client::default()
            .get(echo.http_url())
            .upgrade()
            .send()
            .await
            .unwrap()
            .into_websocket()
            .await
            .unwrap();

        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_shorthand() {
        let echo = TestServer::new().await;

        let websocket = websocket(echo.http_url()).await.unwrap();
        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_with_ws_scheme() {
        let echo = TestServer::new().await;
        let websocket = websocket(echo.ws_url()).await.unwrap();

        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_close() {
        let echo = TestServer::new().await;

        let websocket = websocket(echo.http_url()).await.unwrap();
        websocket
            .close(CloseCode::Normal, Some("test"))
            .await
            .expect("close returned an error");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_send_close_frame() {
        let echo = TestServer::new().await;

        let mut websocket = websocket(echo.http_url()).await.unwrap();
        websocket
            .send(Message::Close {
                code: CloseCode::Normal,
                reason: "Can you please reply with a close frame?".into(),
            })
            .await
            .unwrap();

        let mut close_received = false;
        while let Some(message) = websocket.try_next().await.unwrap() {
            if let Message::Close { code, .. } = message {
                assert_eq!(code, CloseCode::Normal);
                close_received = true;
            }
        }

        assert!(close_received, "No close frame was received");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    #[cfg_attr(
        target_arch = "wasm32",
        ignore = "echo.websocket.org ignores subprotocols"
    )]
    async fn test_with_subprotocol() {
        let echo = TestServer::new().await;

        let mut websocket = Client::default()
            .get(echo.http_url())
            .upgrade()
            .protocols(["chat"])
            .send()
            .await
            .unwrap()
            .into_websocket()
            .await
            .unwrap();

        assert_eq!(websocket.protocol(), Some("chat"));

        let message = websocket.try_next().await.unwrap().unwrap();
        match message {
            Message::Text(s) => {
                assert_eq!(s, "protocol: chat");
            }
            _ => {
                panic!("Expected text message with selected protocol");
            }
        }
    }

    #[test]
    fn close_code_from_u16() {
        let byte = 1008u16;
        assert_eq!(CloseCode::from(byte), CloseCode::Policy);
    }

    #[test]
    fn close_code_into_u16() {
        let text = CloseCode::Away;
        let byte: u16 = text.into();
        assert_eq!(byte, 1001u16);
        assert_eq!(u16::from(text), 1001u16);
    }
}