rquest/client/websocket/
mod.rs

1//! WebSocket Upgrade
2
3#[cfg(feature = "json")]
4mod json;
5mod message;
6
7use std::{
8    borrow::Cow,
9    fmt,
10    net::{IpAddr, Ipv4Addr, Ipv6Addr},
11    ops::{Deref, DerefMut},
12    pin::Pin,
13    task::{Context, Poll, ready},
14};
15
16use crate::{Error, RequestBuilder, Response, error, proxy::IntoProxy};
17use futures_util::{Sink, SinkExt, Stream, StreamExt};
18use http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Version, header, uri::Scheme};
19use hyper2::ext::Protocol;
20use serde::Serialize;
21use tokio_tungstenite::tungstenite::{self, protocol};
22use tungstenite::protocol::WebSocketConfig;
23
24pub use message::{CloseCode, CloseFrame, Message, Utf8Bytes};
25
/// A WebSocket stream.
///
/// Alias for [`tokio_tungstenite::WebSocketStream`] layered over this crate's
/// upgraded connection type (`crate::Upgraded`).
pub type WebSocketStream = tokio_tungstenite::WebSocketStream<crate::Upgraded>;
28
/// Wrapper for [`RequestBuilder`] that performs the
/// websocket handshake when sent.
///
/// Most setters forward to the wrapped [`RequestBuilder`]; the remaining ones
/// collect WebSocket-specific settings that are applied by `send`.
#[derive(Debug)]
pub struct WebSocketRequestBuilder {
    /// The underlying HTTP request builder that carries URL, headers, proxy, etc.
    inner: RequestBuilder,
    /// Caller-supplied value for the `Sec-WebSocket-Key` request header.
    /// When `None`, `send` generates one for HTTP/1.x handshakes.
    accept_key: Option<Cow<'static, str>>,
    /// Subprotocols to offer in `Sec-WebSocket-Protocol`; `None` or an empty
    /// list means the header is not sent.
    protocols: Option<Vec<Cow<'static, str>>>,
    /// tungstenite configuration handed to the stream once the upgrade succeeds.
    config: WebSocketConfig,
}
38
39impl WebSocketRequestBuilder {
40    /// Creates a new WebSocket request builder.
41    pub fn new(inner: RequestBuilder) -> Self {
42        Self {
43            inner,
44            accept_key: None,
45            protocols: None,
46            config: WebSocketConfig::default(),
47        }
48    }
49
50    /// Sets a custom WebSocket accept key.
51    ///
52    /// This method allows you to set a custom WebSocket accept key for the connection.
53    ///
54    /// # Arguments
55    ///
56    /// * `key` - The custom WebSocket accept key to set.
57    ///
58    /// # Returns
59    ///
60    /// * `Self` - The modified instance with the custom WebSocket accept key.
61    pub fn accept_key<K>(mut self, key: K) -> Self
62    where
63        K: Into<Cow<'static, str>>,
64    {
65        self.accept_key = Some(key.into());
66        self
67    }
68
69    /// Sets the websocket subprotocols to request.
70    ///
71    /// This method allows you to specify the subprotocols that the websocket client
72    /// should request during the handshake. Subprotocols are used to define the type
73    /// of communication expected over the websocket connection.
74    ///
75    /// # Arguments
76    ///
77    /// * `protocols` - A list of subprotocols, which can be converted into a `Cow<'static, [String]>`.
78    ///
79    /// # Returns
80    ///
81    /// * `Self` - The modified instance with the updated subprotocols.
82    ///
83    /// # Example
84    ///
85    /// ```
86    /// let request = WebSocketRequestBuilder::new(builder)
87    ///     .protocols(["protocol1", "protocol2"])
88    ///     .build();
89    /// ```
90    pub fn protocols<P>(mut self, protocols: P) -> Self
91    where
92        P: IntoIterator,
93        P::Item: Into<Cow<'static, str>>,
94    {
95        let protocols = protocols.into_iter().map(Into::into).collect();
96        self.protocols = Some(protocols);
97        self
98    }
99
100    /// Sets the websocket max_frame_size configuration.
101    pub fn max_frame_size(mut self, max_frame_size: usize) -> Self {
102        self.config.max_frame_size = Some(max_frame_size);
103        self
104    }
105
106    /// Sets the websocket read_buffer_size configuration.
107    pub fn read_buffer_size(mut self, read_buffer_size: usize) -> Self {
108        self.config.read_buffer_size = read_buffer_size;
109        self
110    }
111
112    /// Sets the websocket write_buffer_size configuration.
113    pub fn write_buffer_size(mut self, write_buffer_size: usize) -> Self {
114        self.config.write_buffer_size = write_buffer_size;
115        self
116    }
117
118    /// Sets the websocket max_write_buffer_size configuration.
119    pub fn max_write_buffer_size(mut self, max_write_buffer_size: usize) -> Self {
120        self.config.max_write_buffer_size = max_write_buffer_size;
121        self
122    }
123
124    /// Sets the websocket max_message_size configuration.
125    pub fn max_message_size(mut self, max_message_size: usize) -> Self {
126        self.config.max_message_size = Some(max_message_size);
127        self
128    }
129
130    /// Sets the websocket accept_unmasked_frames configuration.
131    pub fn accept_unmasked_frames(mut self, accept_unmasked_frames: bool) -> Self {
132        self.config.accept_unmasked_frames = accept_unmasked_frames;
133        self
134    }
135
136    /// Configures the WebSocket connection to use HTTP/2.
137    ///
138    /// This method sets the HTTP version to HTTP/2 for the WebSocket connection.
139    /// If the server does not support HTTP/2 WebSocket connections, the connection attempt will fail.
140    ///
141    /// # Returns
142    ///
143    /// * `Self` - The modified instance with the HTTP version set to HTTP/2.
144    pub fn use_http2(mut self) -> Self {
145        self.inner = self.inner.version(Version::HTTP_2);
146        self
147    }
148
149    /// Add a `Header` to this Request.
150    pub fn header<K, V>(mut self, key: K, value: V) -> Self
151    where
152        HeaderName: TryFrom<K>,
153        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
154        HeaderValue: TryFrom<V>,
155        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
156    {
157        self.inner = self.inner.header(key, value);
158        self
159    }
160
161    /// Add a `Header` to append to the request.
162    pub fn header_append<K, V>(mut self, key: K, value: V) -> Self
163    where
164        HeaderName: TryFrom<K>,
165        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
166        HeaderValue: TryFrom<V>,
167        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
168    {
169        self.inner = self.inner.header_append(key, value);
170        self
171    }
172
173    /// Add a set of Headers to the existing ones on this Request.
174    ///
175    /// The headers will be merged in to any already set.
176    pub fn headers(mut self, headers: HeaderMap) -> Self {
177        self.inner = self.inner.headers(headers);
178        self
179    }
180
181    /// Enable HTTP authentication.
182    pub fn auth<V>(mut self, value: V) -> Self
183    where
184        HeaderValue: TryFrom<V>,
185        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
186    {
187        self.inner = self.inner.auth(value);
188        self
189    }
190
191    /// Enable HTTP basic authentication.
192    pub fn basic_auth<U, P>(mut self, username: U, password: Option<P>) -> Self
193    where
194        U: fmt::Display,
195        P: fmt::Display,
196    {
197        self.inner = self.inner.basic_auth(username, password);
198        self
199    }
200
201    /// Enable HTTP bearer authentication.
202    pub fn bearer_auth<T>(mut self, token: T) -> Self
203    where
204        T: fmt::Display,
205    {
206        self.inner = self.inner.bearer_auth(token);
207        self
208    }
209
210    /// Modify the query string of the URL.
211    pub fn query<T: Serialize + ?Sized>(mut self, query: &T) -> Self {
212        self.inner = self.inner.query(query);
213        self
214    }
215
216    /// Set the proxy for this request.
217    pub fn proxy<U: IntoProxy>(mut self, proxy: U) -> Self {
218        self.inner = self.inner.proxy(proxy);
219        self
220    }
221
222    /// Set the local address for this request.
223    pub fn local_address<V>(mut self, local_address: V) -> Self
224    where
225        V: Into<Option<IpAddr>>,
226    {
227        self.inner = self.inner.local_address(local_address);
228        self
229    }
230
231    /// Set the local addresses for this request.
232    pub fn local_addresses<V4, V6>(mut self, ipv4: V4, ipv6: V6) -> Self
233    where
234        V4: Into<Option<Ipv4Addr>>,
235        V6: Into<Option<Ipv6Addr>>,
236    {
237        self.inner = self.inner.local_addresses(ipv4, ipv6);
238        self
239    }
240
241    /// Set the interface for this request.
242    #[cfg(any(
243        target_os = "android",
244        target_os = "fuchsia",
245        target_os = "linux",
246        all(
247            feature = "apple-network-device-binding",
248            any(
249                target_os = "ios",
250                target_os = "visionos",
251                target_os = "macos",
252                target_os = "tvos",
253                target_os = "watchos",
254            )
255        )
256    ))]
257    #[cfg_attr(docsrs, doc(cfg(feature = "apple-network-device-binding")))]
258    pub fn interface<I>(mut self, interface: I) -> Self
259    where
260        I: Into<std::borrow::Cow<'static, str>>,
261    {
262        self.inner = self.inner.interface(interface);
263        self
264    }
265
266    /// Sends the request and returns and [`WebSocketResponse`].
267    pub async fn send(self) -> Result<WebSocketResponse, Error> {
268        let (client, request) = self.inner.build_split();
269        let mut request = request?;
270
271        // Ensure the scheme is http or https
272        let url = request.url_mut();
273        let new_scheme = match url.scheme() {
274            "ws" => Scheme::HTTP,
275            "wss" => Scheme::HTTPS,
276            _ => {
277                return Err(error::url_bad_scheme(url.clone()));
278            }
279        };
280
281        // Update the scheme
282        url.set_scheme(new_scheme.as_str())
283            .map_err(|_| error::url_bad_scheme(url.clone()))?;
284
285        // Get the version of the request
286        // If the version is not set, use the default version
287        let version = request.version().unwrap_or(Version::HTTP_11);
288
289        // Set the headers for the websocket handshake
290        let headers = request.headers_mut();
291        headers.insert(
292            header::SEC_WEBSOCKET_VERSION,
293            HeaderValue::from_static("13"),
294        );
295
296        // Ensure the request is HTTP 1.1/HTTP 2
297        let accept_key = match version {
298            Version::HTTP_10 | Version::HTTP_11 => {
299                // Generate a nonce if one wasn't provided
300                let nonce = self
301                    .accept_key
302                    .unwrap_or_else(|| Cow::Owned(tungstenite::handshake::client::generate_key()));
303
304                headers.insert(header::UPGRADE, HeaderValue::from_static("websocket"));
305                headers.insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
306                headers.insert(header::SEC_WEBSOCKET_KEY, HeaderValue::from_str(&nonce)?);
307
308                *request.method_mut() = Method::GET;
309                *request.version_mut() = Some(Version::HTTP_11);
310                Some(nonce)
311            }
312            Version::HTTP_2 => {
313                *request.method_mut() = Method::CONNECT;
314                *request.version_mut() = Some(Version::HTTP_2);
315                *request.protocol_mut() = Some(Protocol::from_static("websocket"));
316                None
317            }
318            _ => {
319                return Err(error::upgrade(format!(
320                    "unsupported version: {:?}",
321                    version
322                )));
323            }
324        };
325
326        // Set websocket subprotocols
327        if let Some(ref protocols) = self.protocols {
328            // Sets subprotocols
329            if !protocols.is_empty() {
330                let subprotocols = protocols
331                    .iter()
332                    .map(|s| s.as_ref())
333                    .collect::<Vec<&str>>()
334                    .join(", ");
335
336                request
337                    .headers_mut()
338                    .insert(header::SEC_WEBSOCKET_PROTOCOL, subprotocols.parse()?);
339            }
340        }
341
342        client
343            .execute(request)
344            .await
345            .map(|inner| WebSocketResponse {
346                inner,
347                accept_key,
348                protocols: self.protocols,
349                config: self.config,
350                version,
351            })
352    }
353}
354
/// The server's response to the websocket upgrade request.
///
/// This implements `Deref<Target = Response>`, so you can access all the usual
/// information from the [`Response`].
///
/// The handshake has not been validated yet at this point; that happens in
/// `into_websocket`.
#[derive(Debug)]
pub struct WebSocketResponse {
    /// The raw HTTP response to the handshake request.
    inner: Response,
    /// The `Sec-WebSocket-Key` value that was sent; `None` for HTTP/2 handshakes.
    accept_key: Option<Cow<'static, str>>,
    /// The subprotocols that were requested, used to validate the server's choice.
    protocols: Option<Vec<Cow<'static, str>>>,
    /// tungstenite configuration applied to the stream after the upgrade.
    config: WebSocketConfig,
    /// The HTTP version selected for the handshake request
    /// (HTTP/1.1 when none was set on the request).
    version: Version,
}
367
368impl Deref for WebSocketResponse {
369    type Target = Response;
370
371    fn deref(&self) -> &Self::Target {
372        &self.inner
373    }
374}
375
376impl DerefMut for WebSocketResponse {
377    fn deref_mut(&mut self) -> &mut Self::Target {
378        &mut self.inner
379    }
380}
381
382impl WebSocketResponse {
383    /// Turns the response into a websocket. This checks if the websocket
384    /// handshake was successful.
385    pub async fn into_websocket(self) -> Result<WebSocket, Error> {
386        let (inner, protocol) = {
387            let status = self.inner.status();
388            let headers = self.inner.headers();
389
390            if !matches!(
391                self.inner.version(),
392                Version::HTTP_10 | Version::HTTP_11 | Version::HTTP_2
393            ) {
394                return Err(error::upgrade(format!(
395                    "unexpected version: {:?}",
396                    self.inner.version()
397                )));
398            }
399
400            match self.version {
401                Version::HTTP_10 | Version::HTTP_11 => {
402                    if status != StatusCode::SWITCHING_PROTOCOLS {
403                        let body = self.inner.text().await?;
404                        return Err(error::upgrade(format!("unexpected status code: {}", body)));
405                    }
406
407                    if !header_contains(self.inner.headers(), header::CONNECTION, "upgrade") {
408                        return Err(error::upgrade("missing connection header"));
409                    }
410
411                    if !header_eq(self.inner.headers(), header::UPGRADE, "websocket") {
412                        return Err(error::upgrade("invalid upgrade header"));
413                    }
414
415                    match self
416                        .accept_key
417                        .zip(headers.get(header::SEC_WEBSOCKET_ACCEPT))
418                    {
419                        Some((nonce, header)) => {
420                            if !header.to_str().is_ok_and(|s| {
421                                s == tungstenite::handshake::derive_accept_key(nonce.as_bytes())
422                            }) {
423                                return Err(error::upgrade(format!(
424                                    "invalid accept key: {:?}",
425                                    header
426                                )));
427                            }
428                        }
429                        None => {
430                            return Err(error::upgrade("missing accept key"));
431                        }
432                    }
433                }
434                Version::HTTP_2 => {
435                    if status != StatusCode::OK {
436                        return Err(error::upgrade(format!(
437                            "unexpected status code: {}",
438                            status
439                        )));
440                    }
441                }
442                _ => {
443                    return Err(error::upgrade(format!(
444                        "unsupported version: {:?}",
445                        self.version
446                    )));
447                }
448            }
449
450            let protocol = headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
451
452            match (
453                self.protocols.as_ref().is_none_or(|p| p.is_empty()),
454                &protocol,
455            ) {
456                (true, None) => {
457                    // we didn't request any protocols, so we don't expect one
458                    // in return
459                }
460                (false, None) => {
461                    // server didn't reply with a protocol
462                    return Err(error::upgrade("missing protocol"));
463                }
464                (false, Some(protocol)) => {
465                    if let Some((protocols, protocol)) = self.protocols.zip(protocol.to_str().ok())
466                    {
467                        if !protocols.contains(&Cow::Borrowed(protocol)) {
468                            // the responded protocol is none which we requested
469                            return Err(error::upgrade(format!("invalid protocol: {}", protocol)));
470                        }
471                    } else {
472                        // we didn't request any protocols but got one anyway
473                        return Err(error::upgrade("invalid protocol"));
474                    }
475                }
476                (true, Some(_)) => {
477                    // we didn't request any protocols but got one anyway
478                    return Err(error::upgrade("invalid protocol"));
479                }
480            }
481
482            let upgraded = self.inner.upgrade().await?;
483            let inner = WebSocketStream::from_raw_socket(
484                upgraded,
485                protocol::Role::Client,
486                Some(self.config),
487            )
488            .await;
489
490            (inner, protocol)
491        };
492
493        Ok(WebSocket { inner, protocol })
494    }
495}
496
497/// Checks if the header value is equal to the given value.
498fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
499    if let Some(header) = headers.get(&key) {
500        header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
501    } else {
502        false
503    }
504}
505
506/// Checks if the header value contains the given value.
507fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
508    let header = if let Some(header) = headers.get(&key) {
509        header
510    } else {
511        return false;
512    };
513
514    if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
515        header.to_ascii_lowercase().contains(value)
516    } else {
517        false
518    }
519}
520
/// A websocket connection
///
/// Implements `Stream` for receiving and `Sink` for sending [`Message`]s.
#[derive(Debug)]
pub struct WebSocket {
    /// The underlying tungstenite stream running over the upgraded connection.
    inner: WebSocketStream,
    /// The `Sec-WebSocket-Protocol` value returned by the server, if any.
    protocol: Option<HeaderValue>,
}
527
528impl WebSocket {
529    /// Receive another message.
530    ///
531    /// Returns `None` if the stream has closed.
532    pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
533        self.next().await
534    }
535
536    /// Send a message.
537    pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
538        self.inner
539            .send(msg.into_tungstenite())
540            .await
541            .map_err(Into::into)
542    }
543
544    /// Return the selected WebSocket subprotocol, if one has been chosen.
545    pub fn protocol(&self) -> Option<&HeaderValue> {
546        self.protocol.as_ref()
547    }
548
549    /// Closes the connection with a given code and (optional) reason.
550    pub async fn close(self, code: CloseCode, reason: Option<Utf8Bytes>) -> Result<(), Error> {
551        let mut inner = self.inner;
552        inner
553            .close(Some(tungstenite::protocol::CloseFrame {
554                code: code.0.into(),
555                reason: reason
556                    .unwrap_or(Utf8Bytes::from_static("Goodbye"))
557                    .into_tungstenite(),
558            }))
559            .await
560            .map_err(Into::into)
561    }
562}
563
564impl Stream for WebSocket {
565    type Item = Result<Message, Error>;
566
567    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
568        loop {
569            match ready!(self.inner.poll_next_unpin(cx)) {
570                Some(Ok(msg)) => {
571                    if let Some(msg) = Message::from_tungstenite(msg) {
572                        return Poll::Ready(Some(Ok(msg)));
573                    }
574                }
575                Some(Err(err)) => return Poll::Ready(Some(Err(error::body(err)))),
576                None => return Poll::Ready(None),
577            }
578        }
579    }
580}
581
582impl Sink<Message> for WebSocket {
583    type Error = Error;
584
585    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
586        Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into)
587    }
588
589    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
590        Pin::new(&mut self.inner)
591            .start_send(item.into_tungstenite())
592            .map_err(Into::into)
593    }
594
595    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
596        Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into)
597    }
598
599    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
600        Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into)
601    }
602}