axum/extract/
ws.rs

1//! Handle WebSocket connections.
2//!
3//! # Example
4//!
5//! ```
6//! use axum::{
7//!     extract::ws::{WebSocketUpgrade, WebSocket},
8//!     routing::any,
9//!     response::{IntoResponse, Response},
10//!     Router,
11//! };
12//!
13//! let app = Router::new().route("/ws", any(handler));
14//!
15//! async fn handler(ws: WebSocketUpgrade) -> Response {
16//!     ws.on_upgrade(handle_socket)
17//! }
18//!
19//! async fn handle_socket(mut socket: WebSocket) {
20//!     while let Some(msg) = socket.recv().await {
21//!         let msg = if let Ok(msg) = msg {
22//!             msg
23//!         } else {
24//!             // client disconnected
25//!             return;
26//!         };
27//!
28//!         if socket.send(msg).await.is_err() {
29//!             // client disconnected
30//!             return;
31//!         }
32//!     }
33//! }
34//! # let _: Router = app;
35//! ```
36//!
37//! # Passing data and/or state to an `on_upgrade` callback
38//!
39//! ```
40//! use axum::{
41//!     extract::{ws::{WebSocketUpgrade, WebSocket}, State},
42//!     response::Response,
43//!     routing::any,
44//!     Router,
45//! };
46//!
47//! #[derive(Clone)]
48//! struct AppState {
49//!     // ...
50//! }
51//!
52//! async fn handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
53//!     ws.on_upgrade(|socket| handle_socket(socket, state))
54//! }
55//!
56//! async fn handle_socket(socket: WebSocket, state: AppState) {
57//!     // ...
58//! }
59//!
60//! let app = Router::new()
61//!     .route("/ws", any(handler))
62//!     .with_state(AppState { /* ... */ });
63//! # let _: Router = app;
64//! ```
65//!
66//! # Read and write concurrently
67//!
68//! If you need to read and write concurrently from a [`WebSocket`] you can use
69//! [`StreamExt::split`]:
70//!
71//! ```rust,no_run
//! use axum::extract::ws::{WebSocket, Message};
//! use futures_util::stream::{StreamExt, SplitSink, SplitStream};
//!
//! async fn handle_socket(socket: WebSocket) {
//!     let (sender, receiver) = socket.split();
77//!
78//!     tokio::spawn(write(sender));
79//!     tokio::spawn(read(receiver));
80//! }
81//!
82//! async fn read(receiver: SplitStream<WebSocket>) {
83//!     // ...
84//! }
85//!
86//! async fn write(sender: SplitSink<WebSocket, Message>) {
87//!     // ...
88//! }
89//! ```
90//!
91//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
92
93use self::rejection::*;
94use super::FromRequestParts;
95use crate::{body::Bytes, response::Response, Error};
96use axum_core::body::Body;
97use futures_util::{
98    sink::{Sink, SinkExt},
99    stream::{FusedStream, Stream, StreamExt},
100};
101use http::{
102    header::{self, HeaderMap, HeaderName, HeaderValue},
103    request::Parts,
104    Method, StatusCode, Version,
105};
106use hyper_util::rt::TokioIo;
107use sha1::{Digest, Sha1};
108use std::{
109    borrow::Cow,
110    future::Future,
111    pin::Pin,
112    task::{ready, Context, Poll},
113};
114use tokio_tungstenite::{
115    tungstenite::{
116        self as ts,
117        protocol::{self, WebSocketConfig},
118    },
119    WebSocketStream,
120};
121
/// Extractor for establishing WebSocket connections.
///
/// For HTTP/1.1 requests, this extractor requires the request method to be `GET`;
/// in later versions, `CONNECT` is used instead.
/// To support both, it should be used with [`any`](crate::routing::any).
///
/// See the [module docs](self) for an example.
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
#[must_use]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
    /// Protocol settings (buffer sizes, message/frame limits, masking policy) passed to the
    /// WebSocket stream once the upgrade completes; adjusted by the builder methods.
    config: WebSocketConfig,
    /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
    protocol: Option<HeaderValue>,
    /// `None` if HTTP/2+ WebSockets are used.
    sec_websocket_key: Option<HeaderValue>,
    /// hyper's upgrade handle, awaited in a background task by `on_upgrade`.
    on_upgrade: hyper::upgrade::OnUpgrade,
    /// Callback invoked with the error if awaiting `on_upgrade` fails.
    on_failed_upgrade: F,
    /// The raw `Sec-WebSocket-Protocol` value offered by the client, if any; `protocols`
    /// picks the response protocol from it.
    sec_websocket_protocol: Option<HeaderValue>,
}
143
144impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        f.debug_struct("WebSocketUpgrade")
147            .field("config", &self.config)
148            .field("protocol", &self.protocol)
149            .field("sec_websocket_key", &self.sec_websocket_key)
150            .field("sec_websocket_protocol", &self.sec_websocket_protocol)
151            .finish_non_exhaustive()
152    }
153}
154
155impl<F> WebSocketUpgrade<F> {
156    /// Read buffer capacity. The default value is 128KiB
157    pub fn read_buffer_size(mut self, size: usize) -> Self {
158        self.config.read_buffer_size = size;
159        self
160    }
161
162    /// The target minimum size of the write buffer to reach before writing the data
163    /// to the underlying stream.
164    ///
165    /// The default value is 128 KiB.
166    ///
167    /// If set to `0` each message will be eagerly written to the underlying stream.
168    /// It is often more optimal to allow them to buffer a little, hence the default value.
169    ///
170    /// Note: [`flush`](SinkExt::flush) will always fully write the buffer regardless.
171    pub fn write_buffer_size(mut self, size: usize) -> Self {
172        self.config.write_buffer_size = size;
173        self
174    }
175
176    /// The max size of the write buffer in bytes. Setting this can provide backpressure
177    /// in the case the write buffer is filling up due to write errors.
178    ///
179    /// The default value is unlimited.
180    ///
181    /// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size)
182    /// when writes to the underlying stream are failing. So the **write buffer can not
183    /// fill up if you are not observing write errors even if not flushing**.
184    ///
185    /// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size)
186    /// and probably a little more depending on error handling strategy.
187    pub fn max_write_buffer_size(mut self, max: usize) -> Self {
188        self.config.max_write_buffer_size = max;
189        self
190    }
191
192    /// Set the maximum message size (defaults to 64 megabytes)
193    pub fn max_message_size(mut self, max: usize) -> Self {
194        self.config.max_message_size = Some(max);
195        self
196    }
197
198    /// Set the maximum frame size (defaults to 16 megabytes)
199    pub fn max_frame_size(mut self, max: usize) -> Self {
200        self.config.max_frame_size = Some(max);
201        self
202    }
203
204    /// Allow server to accept unmasked frames (defaults to false)
205    pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
206        self.config.accept_unmasked_frames = accept;
207        self
208    }
209
210    /// Set the known protocols.
211    ///
212    /// If the protocol name specified by `Sec-WebSocket-Protocol` header
213    /// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
214    /// return the protocol name.
215    ///
216    /// The protocols should be listed in decreasing order of preference: if the client offers
217    /// multiple protocols that the server could support, the server will pick the first one in
218    /// this list.
219    ///
220    /// # Examples
221    ///
222    /// ```
223    /// use axum::{
224    ///     extract::ws::{WebSocketUpgrade, WebSocket},
225    ///     routing::any,
226    ///     response::{IntoResponse, Response},
227    ///     Router,
228    /// };
229    ///
230    /// let app = Router::new().route("/ws", any(handler));
231    ///
232    /// async fn handler(ws: WebSocketUpgrade) -> Response {
233    ///     ws.protocols(["graphql-ws", "graphql-transport-ws"])
234    ///         .on_upgrade(|socket| async {
235    ///             // ...
236    ///         })
237    /// }
238    /// # let _: Router = app;
239    /// ```
240    pub fn protocols<I>(mut self, protocols: I) -> Self
241    where
242        I: IntoIterator,
243        I::Item: Into<Cow<'static, str>>,
244    {
245        if let Some(req_protocols) = self
246            .sec_websocket_protocol
247            .as_ref()
248            .and_then(|p| p.to_str().ok())
249        {
250            self.protocol = protocols
251                .into_iter()
252                // FIXME: This will often allocate a new `String` and so is less efficient than it
253                // could be. But that can't be fixed without breaking changes to the public API.
254                .map(Into::into)
255                .find(|protocol| {
256                    req_protocols
257                        .split(',')
258                        .any(|req_protocol| req_protocol.trim() == protocol)
259                })
260                .map(|protocol| match protocol {
261                    Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
262                    Cow::Borrowed(s) => HeaderValue::from_static(s),
263                });
264        }
265
266        self
267    }
268
269    /// Return the selected WebSocket subprotocol, if one has been chosen.
270    ///
271    /// If [`protocols()`][Self::protocols] has been called and a matching
272    /// protocol has been selected, the return value will be `Some` containing
273    /// said protocol. Otherwise, it will be `None`.
274    pub fn selected_protocol(&self) -> Option<&HeaderValue> {
275        self.protocol.as_ref()
276    }
277
278    /// Provide a callback to call if upgrading the connection fails.
279    ///
280    /// The connection upgrade is performed in a background task. If that fails this callback
281    /// will be called.
282    ///
283    /// By default any errors will be silently ignored.
284    ///
285    /// # Example
286    ///
287    /// ```
288    /// use axum::{
289    ///     extract::{WebSocketUpgrade},
290    ///     response::Response,
291    /// };
292    ///
293    /// async fn handler(ws: WebSocketUpgrade) -> Response {
294    ///     ws.on_failed_upgrade(|error| {
295    ///         report_error(error);
296    ///     })
297    ///     .on_upgrade(|socket| async { /* ... */ })
298    /// }
299    /// #
300    /// # fn report_error(_: axum::Error) {}
301    /// ```
302    pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
303    where
304        C: OnFailedUpgrade,
305    {
306        WebSocketUpgrade {
307            config: self.config,
308            protocol: self.protocol,
309            sec_websocket_key: self.sec_websocket_key,
310            on_upgrade: self.on_upgrade,
311            on_failed_upgrade: callback,
312            sec_websocket_protocol: self.sec_websocket_protocol,
313        }
314    }
315
316    /// Finalize upgrading the connection and call the provided callback with
317    /// the stream.
318    #[must_use = "to set up the WebSocket connection, this response must be returned"]
319    pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
320    where
321        C: FnOnce(WebSocket) -> Fut + Send + 'static,
322        Fut: Future<Output = ()> + Send + 'static,
323        F: OnFailedUpgrade,
324    {
325        let on_upgrade = self.on_upgrade;
326        let config = self.config;
327        let on_failed_upgrade = self.on_failed_upgrade;
328
329        let protocol = self.protocol.clone();
330
331        tokio::spawn(async move {
332            let upgraded = match on_upgrade.await {
333                Ok(upgraded) => upgraded,
334                Err(err) => {
335                    on_failed_upgrade.call(Error::new(err));
336                    return;
337                }
338            };
339            let upgraded = TokioIo::new(upgraded);
340
341            let socket =
342                WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
343                    .await;
344            let socket = WebSocket {
345                inner: socket,
346                protocol,
347            };
348            callback(socket).await;
349        });
350
351        let mut response = if let Some(sec_websocket_key) = &self.sec_websocket_key {
352            // If `sec_websocket_key` was `Some`, we are using HTTP/1.1.
353
354            #[allow(clippy::declare_interior_mutable_const)]
355            const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
356            #[allow(clippy::declare_interior_mutable_const)]
357            const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
358
359            Response::builder()
360                .status(StatusCode::SWITCHING_PROTOCOLS)
361                .header(header::CONNECTION, UPGRADE)
362                .header(header::UPGRADE, WEBSOCKET)
363                .header(
364                    header::SEC_WEBSOCKET_ACCEPT,
365                    sign(sec_websocket_key.as_bytes()),
366                )
367                .body(Body::empty())
368                .unwrap()
369        } else {
370            // Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond
371            // with a 2XX with an empty body:
372            // <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
373            Response::new(Body::empty())
374        };
375
376        if let Some(protocol) = self.protocol {
377            response
378                .headers_mut()
379                .insert(header::SEC_WEBSOCKET_PROTOCOL, protocol);
380        }
381
382        response
383    }
384}
385
/// What to do when a connection upgrade fails.
///
/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
///
/// Implemented for every `FnOnce(Error) + Send + 'static` closure, and by
/// [`DefaultOnFailedUpgrade`], which ignores the error.
pub trait OnFailedUpgrade: Send + 'static {
    /// Call the callback.
    ///
    /// Takes `self` by value, so each callback can run at most once.
    fn call(self, error: Error);
}
393
394impl<F> OnFailedUpgrade for F
395where
396    F: FnOnce(Error) + Send + 'static,
397{
398    fn call(self, error: Error) {
399        self(error)
400    }
401}
402
/// The default `OnFailedUpgrade` used by `WebSocketUpgrade`.
///
/// It simply ignores the error.
///
/// Marked `#[non_exhaustive]`, so it can only be constructed inside this crate.
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultOnFailedUpgrade;
409
410impl OnFailedUpgrade for DefaultOnFailedUpgrade {
411    #[inline]
412    fn call(self, _error: Error) {}
413}
414
415impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpgrade>
416where
417    S: Send + Sync,
418{
419    type Rejection = WebSocketUpgradeRejection;
420
421    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
422        let sec_websocket_key = if parts.version <= Version::HTTP_11 {
423            if parts.method != Method::GET {
424                return Err(MethodNotGet.into());
425            }
426
427            if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
428                return Err(InvalidConnectionHeader.into());
429            }
430
431            if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
432                return Err(InvalidUpgradeHeader.into());
433            }
434
435            Some(
436                parts
437                    .headers
438                    .get(header::SEC_WEBSOCKET_KEY)
439                    .ok_or(WebSocketKeyHeaderMissing)?
440                    .clone(),
441            )
442        } else {
443            if parts.method != Method::CONNECT {
444                return Err(MethodNotConnect.into());
445            }
446
447            // if this feature flag is disabled, we won’t be receiving an HTTP/2 request to begin
448            // with.
449            #[cfg(feature = "http2")]
450            if parts
451                .extensions
452                .get::<hyper::ext::Protocol>()
453                .map_or(true, |p| p.as_str() != "websocket")
454            {
455                return Err(InvalidProtocolPseudoheader.into());
456            }
457
458            None
459        };
460
461        if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
462            return Err(InvalidWebSocketVersionHeader.into());
463        }
464
465        let on_upgrade = parts
466            .extensions
467            .remove::<hyper::upgrade::OnUpgrade>()
468            .ok_or(ConnectionNotUpgradable)?;
469
470        let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
471
472        Ok(Self {
473            config: Default::default(),
474            protocol: None,
475            sec_websocket_key,
476            on_upgrade,
477            sec_websocket_protocol,
478            on_failed_upgrade: DefaultOnFailedUpgrade,
479        })
480    }
481}
482
483fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
484    if let Some(header) = headers.get(&key) {
485        header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
486    } else {
487        false
488    }
489}
490
491fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
492    let header = if let Some(header) = headers.get(&key) {
493        header
494    } else {
495        return false;
496    };
497
498    if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
499        header.to_ascii_lowercase().contains(value)
500    } else {
501        false
502    }
503}
504
/// A stream of WebSocket messages.
///
/// See [the module level documentation](self) for more details.
///
/// Implements [`Stream`] (yielding [`Message`]s) and [`Sink<Message>`](Sink), so it can be
/// split into independent reader and writer halves with [`StreamExt::split`].
#[derive(Debug)]
pub struct WebSocket {
    /// The upgraded connection, driven by tungstenite.
    inner: WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
    /// The subprotocol selected during the upgrade, if any.
    protocol: Option<HeaderValue>,
}
513
514impl WebSocket {
515    /// Receive another message.
516    ///
517    /// Returns `None` if the stream has closed.
518    pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
519        self.next().await
520    }
521
522    /// Send a message.
523    pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
524        self.inner
525            .send(msg.into_tungstenite())
526            .await
527            .map_err(Error::new)
528    }
529
530    /// Return the selected WebSocket subprotocol, if one has been chosen.
531    pub fn protocol(&self) -> Option<&HeaderValue> {
532        self.protocol.as_ref()
533    }
534}
535
536impl FusedStream for WebSocket {
537    /// Returns true if the websocket has been terminated.
538    fn is_terminated(&self) -> bool {
539        self.inner.is_terminated()
540    }
541}
542
543impl Stream for WebSocket {
544    type Item = Result<Message, Error>;
545
546    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
547        loop {
548            match ready!(self.inner.poll_next_unpin(cx)) {
549                Some(Ok(msg)) => {
550                    if let Some(msg) = Message::from_tungstenite(msg) {
551                        return Poll::Ready(Some(Ok(msg)));
552                    }
553                }
554                Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))),
555                None => return Poll::Ready(None),
556            }
557        }
558    }
559}
560
561impl Sink<Message> for WebSocket {
562    type Error = Error;
563
564    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
565        Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
566    }
567
568    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
569        Pin::new(&mut self.inner)
570            .start_send(item.into_tungstenite())
571            .map_err(Error::new)
572    }
573
574    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
575        Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
576    }
577
578    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
579        Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
580    }
581}
582
/// UTF-8 wrapper for [Bytes].
///
/// A [`Utf8Bytes`] is always guaranteed to contain valid UTF-8.
///
/// It dereferences to [`str`] and can be built from `String`, `&str`, `&String`, or
/// (fallibly) from [`Bytes`] and `Vec<u8>`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Utf8Bytes(ts::Utf8Bytes);
588
589impl Utf8Bytes {
590    /// Creates from a static str.
591    #[inline]
592    #[must_use]
593    pub const fn from_static(str: &'static str) -> Self {
594        Self(ts::Utf8Bytes::from_static(str))
595    }
596
597    /// Returns as a string slice.
598    #[inline]
599    pub fn as_str(&self) -> &str {
600        self.0.as_str()
601    }
602
603    fn into_tungstenite(self) -> ts::Utf8Bytes {
604        self.0
605    }
606}
607
608impl std::ops::Deref for Utf8Bytes {
609    type Target = str;
610
611    /// ```
612    /// /// Example fn that takes a str slice
613    /// fn a(s: &str) {}
614    ///
615    /// let data = axum::extract::ws::Utf8Bytes::from_static("foo123");
616    ///
617    /// // auto-deref as arg
618    /// a(&data);
619    ///
620    /// // deref to str methods
621    /// assert_eq!(data.len(), 6);
622    /// ```
623    #[inline]
624    fn deref(&self) -> &Self::Target {
625        self.as_str()
626    }
627}
628
629impl std::fmt::Display for Utf8Bytes {
630    #[inline]
631    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
632        f.write_str(self.as_str())
633    }
634}
635
636impl TryFrom<Bytes> for Utf8Bytes {
637    type Error = std::str::Utf8Error;
638
639    #[inline]
640    fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
641        Ok(Self(bytes.try_into()?))
642    }
643}
644
645impl TryFrom<Vec<u8>> for Utf8Bytes {
646    type Error = std::str::Utf8Error;
647
648    #[inline]
649    fn try_from(v: Vec<u8>) -> Result<Self, Self::Error> {
650        Ok(Self(v.try_into()?))
651    }
652}
653
654impl From<String> for Utf8Bytes {
655    #[inline]
656    fn from(s: String) -> Self {
657        Self(s.into())
658    }
659}
660
661impl From<&str> for Utf8Bytes {
662    #[inline]
663    fn from(s: &str) -> Self {
664        Self(s.into())
665    }
666}
667
668impl From<&String> for Utf8Bytes {
669    #[inline]
670    fn from(s: &String) -> Self {
671        Self(s.into())
672    }
673}
674
675impl From<Utf8Bytes> for Bytes {
676    #[inline]
677    fn from(Utf8Bytes(bytes): Utf8Bytes) -> Self {
678        bytes.into()
679    }
680}
681
682impl<T> PartialEq<T> for Utf8Bytes
683where
684    for<'a> &'a str: PartialEq<T>,
685{
686    /// ```
687    /// let payload = axum::extract::ws::Utf8Bytes::from_static("foo123");
688    /// assert_eq!(payload, "foo123");
689    /// assert_eq!(payload, "foo123".to_string());
690    /// assert_eq!(payload, &"foo123".to_string());
691    /// assert_eq!(payload, std::borrow::Cow::from("foo123"));
692    /// ```
693    #[inline]
694    fn eq(&self, other: &T) -> bool {
695        self.as_str() == *other
696    }
697}
698
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
///
/// Converted to and from tungstenite's close code whenever a [`Message::Close`] is sent
/// or received.
pub type CloseCode = u16;
701
/// A struct representing the close command.
///
/// Carried by [`Message::Close`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame {
    /// The reason as a code.
    pub code: CloseCode,
    /// The reason as text string.
    pub reason: Utf8Bytes,
}
710
/// A WebSocket message.
//
// This code comes from https://github.com/snapview/tungstenite-rs/blob/master/src/protocol/message.rs and is under following license:
// Copyright (c) 2017 Alexey Galakhov
// Copyright (c) 2016 Jason Housley
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {
    /// A text WebSocket message
    Text(Utf8Bytes),
    /// A binary WebSocket message
    Binary(Bytes),
    /// A ping message with the specified payload
    ///
    /// The payload here must be at most 125 bytes long.
    ///
    /// Ping messages will be automatically responded to by the server, so you do not have to worry
    /// about dealing with them yourself.
    Ping(Bytes),
    /// A pong message with the specified payload
    ///
    /// The payload here must be at most 125 bytes long.
    ///
    /// Pong messages will be automatically sent to the client if a ping message is received, so
    /// you do not have to worry about constructing them yourself unless you want to implement a
    /// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
    Pong(Bytes),
    /// A close message with the optional close frame.
    ///
    /// You may "uncleanly" close a WebSocket connection at any time
    /// by simply dropping the [`WebSocket`].
    /// However, you may also use the graceful closing protocol, in which
    /// 1. peer A sends a close frame, and does not send any further messages;
    /// 2. peer B responds with a close frame, and does not send any further messages;
    /// 3. peer A processes the remaining messages sent by peer B, before finally
    /// 4. both peers close the connection.
    ///
    /// After sending a close frame,
    /// you may still read messages,
    /// but attempts to send another message will error.
    /// After receiving a close frame,
    /// axum will automatically respond with a close frame if necessary
    /// (you do not have to deal with this yourself).
    /// Since no further messages will be received,
    /// you may either do nothing
    /// or explicitly drop the connection.
    Close(Option<CloseFrame>),
}
776
777impl Message {
778    fn into_tungstenite(self) -> ts::Message {
779        match self {
780            Self::Text(text) => ts::Message::Text(text.into_tungstenite()),
781            Self::Binary(binary) => ts::Message::Binary(binary),
782            Self::Ping(ping) => ts::Message::Ping(ping),
783            Self::Pong(pong) => ts::Message::Pong(pong),
784            Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
785                code: ts::protocol::frame::coding::CloseCode::from(close.code),
786                reason: close.reason.into_tungstenite(),
787            })),
788            Self::Close(None) => ts::Message::Close(None),
789        }
790    }
791
792    fn from_tungstenite(message: ts::Message) -> Option<Self> {
793        match message {
794            ts::Message::Text(text) => Some(Self::Text(Utf8Bytes(text))),
795            ts::Message::Binary(binary) => Some(Self::Binary(binary)),
796            ts::Message::Ping(ping) => Some(Self::Ping(ping)),
797            ts::Message::Pong(pong) => Some(Self::Pong(pong)),
798            ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
799                code: close.code.into(),
800                reason: Utf8Bytes(close.reason),
801            }))),
802            ts::Message::Close(None) => Some(Self::Close(None)),
803            // we can ignore `Frame` frames as recommended by the tungstenite maintainers
804            // https://github.com/snapview/tungstenite-rs/issues/268
805            ts::Message::Frame(_) => None,
806        }
807    }
808
809    /// Consume the WebSocket and return it as binary data.
810    pub fn into_data(self) -> Bytes {
811        match self {
812            Self::Text(string) => Bytes::from(string),
813            Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
814            Self::Close(None) => Bytes::new(),
815            Self::Close(Some(frame)) => Bytes::from(frame.reason),
816        }
817    }
818
819    /// Attempt to consume the WebSocket message and convert it to a Utf8Bytes.
820    pub fn into_text(self) -> Result<Utf8Bytes, Error> {
821        match self {
822            Self::Text(string) => Ok(string),
823            Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => {
824                Ok(Utf8Bytes::try_from(data).map_err(Error::new)?)
825            }
826            Self::Close(None) => Ok(Utf8Bytes::default()),
827            Self::Close(Some(frame)) => Ok(frame.reason),
828        }
829    }
830
831    /// Attempt to get a &str from the WebSocket message,
832    /// this will try to convert binary data to utf8.
833    pub fn to_text(&self) -> Result<&str, Error> {
834        match *self {
835            Self::Text(ref string) => Ok(string.as_str()),
836            Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
837                Ok(std::str::from_utf8(data).map_err(Error::new)?)
838            }
839            Self::Close(None) => Ok(""),
840            Self::Close(Some(ref frame)) => Ok(&frame.reason),
841        }
842    }
843
844    /// Create a new text WebSocket message from a stringable.
845    pub fn text<S>(string: S) -> Message
846    where
847        S: Into<Utf8Bytes>,
848    {
849        Message::Text(string.into())
850    }
851
852    /// Create a new binary WebSocket message by converting to `Bytes`.
853    pub fn binary<B>(bin: B) -> Message
854    where
855        B: Into<Bytes>,
856    {
857        Message::Binary(bin.into())
858    }
859}
860
861impl From<String> for Message {
862    fn from(string: String) -> Self {
863        Message::Text(string.into())
864    }
865}
866
867impl<'s> From<&'s str> for Message {
868    fn from(string: &'s str) -> Self {
869        Message::Text(string.into())
870    }
871}
872
873impl<'b> From<&'b [u8]> for Message {
874    fn from(data: &'b [u8]) -> Self {
875        Message::Binary(Bytes::copy_from_slice(data))
876    }
877}
878
879impl From<Bytes> for Message {
880    fn from(data: Bytes) -> Self {
881        Message::Binary(data)
882    }
883}
884
885impl From<Vec<u8>> for Message {
886    fn from(data: Vec<u8>) -> Self {
887        Message::Binary(data.into())
888    }
889}
890
891impl From<Message> for Vec<u8> {
892    fn from(msg: Message) -> Self {
893        msg.into_data().to_vec()
894    }
895}
896
897fn sign(key: &[u8]) -> HeaderValue {
898    use base64::engine::Engine as _;
899
900    let mut sha1 = Sha1::default();
901    sha1.update(key);
902    sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
903    let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
904    HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
905}
906
pub mod rejection {
    //! WebSocket-specific rejections.
    //!
    //! Each type in this module is one way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
    //! extractor can reject a request. All of them are collected into
    //! [`WebSocketUpgradeRejection`].

    use axum_core::__composite_rejection as composite_rejection;
    use axum_core::__define_rejection as define_rejection;

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `GET`"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Responds with `405 Method Not Allowed`; the body tells the client that the
        /// request method must be `GET`.
        pub struct MethodNotGet;
    }

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `CONNECT`"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Responds with `405 Method Not Allowed`; the body tells the client that the
        /// request method must be `CONNECT`.
        pub struct MethodNotConnect;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "Connection header did not include 'upgrade'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Responds with `400 Bad Request` when the `Connection` header does not include
        /// `upgrade`.
        pub struct InvalidConnectionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Upgrade` header did not include 'websocket'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Responds with `400 Bad Request` when the `Upgrade` header does not include
        /// `websocket`.
        pub struct InvalidUpgradeHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`:protocol` pseudo-header did not include 'websocket'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Responds with `400 Bad Request` when the `:protocol` pseudo-header does not
        /// include `websocket`.
        pub struct InvalidProtocolPseudoheader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Version` header did not include '13'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Responds with `400 Bad Request` when the `Sec-WebSocket-Version` header does not
        /// include `13`.
        pub struct InvalidWebSocketVersionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Key` header missing"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Responds with `400 Bad Request` when the `Sec-WebSocket-Key` header is missing.
        pub struct WebSocketKeyHeaderMissing;
    }

    define_rejection! {
        #[status = UPGRADE_REQUIRED]
        #[body = "WebSocket request couldn't be upgraded since no upgrade state was present"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// This rejection is returned if the connection cannot be upgraded, for example if the
        /// request is HTTP/1.0. It responds with `426 Upgrade Required`.
        ///
        /// See [MDN] for more details about connection upgrades.
        ///
        /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Upgrade
        pub struct ConnectionNotUpgradable;
    }

    composite_rejection! {
        /// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
        /// extractor can fail.
        pub enum WebSocketUpgradeRejection {
            MethodNotGet,
            MethodNotConnect,
            InvalidConnectionHeader,
            InvalidUpgradeHeader,
            InvalidProtocolPseudoheader,
            InvalidWebSocketVersionHeader,
            WebSocketKeyHeaderMissing,
            ConnectionNotUpgradable,
        }
    }
}
993
pub mod close_code {
    //! Constants for [`CloseCode`]s.
    //!
    //! The values are the status codes defined in RFC 6455 section 7.4.1, plus codes that were
    //! later added to the IANA "WebSocket Close Code Number Registry".
    //!
    //! Note that [`STATUS`] and [`ABNORMAL`] are *reserved*: they describe a closure locally
    //! and must never be sent by an endpoint as the status code of a Close frame.
    //!
    //! [`CloseCode`]: super::CloseCode

    /// Indicates a normal closure, meaning that the purpose for which the connection was
    /// established has been fulfilled.
    pub const NORMAL: u16 = 1000;

    /// Indicates that an endpoint is "going away", such as a server going down or a browser having
    /// navigated away from a page.
    pub const AWAY: u16 = 1001;

    /// Indicates that an endpoint is terminating the connection due to a protocol error.
    pub const PROTOCOL: u16 = 1002;

    /// Indicates that an endpoint is terminating the connection because it has received a type of
    /// data that it cannot accept.
    ///
    /// For example, an endpoint MAY send this if it understands only text data, but receives a binary message.
    pub const UNSUPPORTED: u16 = 1003;

    /// Indicates that no status code was included in a closing frame.
    ///
    /// This is a reserved value: an endpoint must not send it as the status code of a Close
    /// frame. It exists so that applications can report that a received Close frame carried no
    /// status code.
    pub const STATUS: u16 = 1005;

    /// Indicates an abnormal closure, for example the connection was lost without a Close frame
    /// being sent or received.
    ///
    /// This is a reserved value: an endpoint must not send it as the status code of a Close
    /// frame. It exists so that applications can report an abnormal closure locally.
    pub const ABNORMAL: u16 = 1006;

    /// Indicates that an endpoint is terminating the connection because it has received data
    /// within a message that was not consistent with the type of the message.
    ///
    /// For example, an endpoint received non-UTF-8 data (see RFC 3629) within a text message.
    pub const INVALID: u16 = 1007;

    /// Indicates that an endpoint is terminating the connection because it has received a message
    /// that violates its policy.
    ///
    /// This is a generic status code that can be returned when there is
    /// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to
    /// hide specific details about the policy.
    pub const POLICY: u16 = 1008;

    /// Indicates that an endpoint is terminating the connection because it has received a message
    /// that is too big for it to process.
    pub const SIZE: u16 = 1009;

    /// Indicates that an endpoint (client) is terminating the connection because the server
    /// did not respond to extension negotiation correctly.
    ///
    /// Specifically, the client has expected the server to negotiate one or more extension(s),
    /// but the server didn't return them in the response message of the WebSocket handshake.
    /// The list of extensions that are needed should be given as the reason for closing.
    /// Note that this status code is not used by the server,
    /// because it can fail the WebSocket handshake instead.
    pub const EXTENSION: u16 = 1010;

    /// Indicates that a server is terminating the connection because it encountered an unexpected
    /// condition that prevented it from fulfilling the request.
    pub const ERROR: u16 = 1011;

    /// Indicates that the server is restarting.
    pub const RESTART: u16 = 1012;

    /// Indicates that the server is overloaded and the client should either connect to a different
    /// IP (when multiple targets exist), or reconnect to the same IP when a user has performed an
    /// action.
    pub const AGAIN: u16 = 1013;
}
1062
#[cfg(test)]
mod tests {
    use std::future::ready;

    use super::*;
    use crate::{routing::any, test_helpers::spawn_service, Router};
    use http::{Request, Version};
    use http_body_util::BodyExt as _;
    use hyper_util::rt::TokioExecutor;
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;
    use tokio_tungstenite::tungstenite;
    use tower::ServiceExt;

    // An HTTP/1.0 request carrying otherwise valid WebSocket upgrade headers must be rejected
    // with `ConnectionNotUpgradable`.
    //
    // The handler extracts `Result<WebSocketUpgrade, _>`, so the rejection is delivered to the
    // handler (and asserted on there) rather than being turned into an error response. The
    // handler then returns `()`, which is why the final status check expects `200 OK`.
    #[crate::test]
    async fn rejects_http_1_0_requests() {
        let svc = any(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
            let rejection = ws.unwrap_err();
            assert!(matches!(
                rejection,
                WebSocketUpgradeRejection::ConnectionNotUpgradable(_)
            ));
            std::future::ready(())
        });

        let req = Request::builder()
            .version(Version::HTTP_10)
            .method(Method::GET)
            .header("upgrade", "websocket")
            .header("connection", "Upgrade")
            .header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==")
            .header("sec-websocket-version", "13")
            .body(Body::empty())
            .unwrap();

        let res = svc.oneshot(req).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    // Compile-only check (never called): a handler that uses `on_upgrade` without configuring
    // a failure callback can be routed.
    #[allow(dead_code)]
    fn default_on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", any(handler));
    }

    // Compile-only check (never called): `on_failed_upgrade` accepts a closure taking `Error`
    // and can be chained before `on_upgrade`.
    #[allow(dead_code)]
    fn on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_failed_upgrade(|_error: Error| println!("oops!"))
                .on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", any(handler));
    }

    // End-to-end HTTP/1.1 test: a real `tokio_tungstenite` client connects over TCP to the
    // echo app, offering the subprotocols in `TEST_ECHO_APP_REQ_SUBPROTO`.
    #[crate::test]
    async fn integration_test() {
        let addr = spawn_service(echo_app());
        let uri = format!("ws://{addr}/echo").try_into().unwrap();
        let req = tungstenite::client::ClientRequestBuilder::new(uri)
            .with_sub_protocol(TEST_ECHO_APP_REQ_SUBPROTO);
        let (socket, response) = tokio_tungstenite::connect_async(req).await.unwrap();
        test_echo_app(socket, response.headers()).await;
    }

    // End-to-end HTTP/2 test: a raw hyper HTTP/2 client opens a WebSocket with an extended
    // `CONNECT` request (`:protocol` = `websocket`), then runs the same echo checks over the
    // upgraded stream.
    #[crate::test]
    #[cfg(feature = "http2")]
    async fn http2() {
        let addr = spawn_service(echo_app());
        let io = TokioIo::new(TcpStream::connect(addr).await.unwrap());
        let (mut send_request, conn) =
            hyper::client::conn::http2::Builder::new(TokioExecutor::new())
                .handshake(io)
                .await
                .unwrap();

        // Wait a little for the SETTINGS frame to go through…
        // The server must advertise extended CONNECT support in its SETTINGS before the
        // `CONNECT` + `:protocol` request below is allowed, which the assert checks.
        // NOTE(review): a fixed number of yields is timing-dependent and may be flaky under load.
        for _ in 0..10 {
            tokio::task::yield_now().await;
        }
        assert!(conn.is_extended_connect_protocol_enabled());
        // Drive the connection in the background so requests on `send_request` make progress.
        tokio::spawn(async {
            conn.await.unwrap();
        });

        let req = Request::builder()
            .method(Method::CONNECT)
            .extension(hyper::ext::Protocol::from_static("websocket"))
            .uri("/echo")
            .header("sec-websocket-version", "13")
            .header("sec-websocket-protocol", TEST_ECHO_APP_REQ_SUBPROTO)
            .header("Host", "server.example.com")
            .body(Body::empty())
            .unwrap();

        let mut response = send_request.send_request(req).await.unwrap();
        let status = response.status();
        // Surface the server's rejection body in the panic message to ease debugging.
        if status != 200 {
            let body = response.into_body().collect().await.unwrap().to_bytes();
            let body = std::str::from_utf8(&body).unwrap();
            panic!("response status was {status}: {body}");
        }
        let upgraded = hyper::upgrade::on(&mut response).await.unwrap();
        let upgraded = TokioIo::new(upgraded);
        let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
        test_echo_app(socket, response.headers()).await;
    }

    // Router exposing `/echo`.
    //
    // The server supports the subprotocols `echo2` and `echo`. The tests' client offers
    // `echo3, echo`, so `echo` is the expected selection. This is asserted both before
    // (`selected_protocol`) and after (`socket.protocol`) the upgrade. Text, binary and close
    // messages are sent back unchanged.
    fn echo_app() -> Router {
        async fn handle_socket(mut socket: WebSocket) {
            assert_eq!(socket.protocol().unwrap(), "echo");
            while let Some(Ok(msg)) = socket.recv().await {
                match msg {
                    Message::Text(_) | Message::Binary(_) | Message::Close(_) => {
                        if socket.send(msg).await.is_err() {
                            break;
                        }
                    }
                    Message::Ping(_) | Message::Pong(_) => {
                        // tungstenite will respond to pings automatically
                    }
                }
            }
        }

        Router::new().route(
            "/echo",
            any(|ws: WebSocketUpgrade| {
                let ws = ws.protocols(["echo2", "echo"]);
                assert_eq!(ws.selected_protocol().unwrap(), "echo");
                ready(ws.on_upgrade(handle_socket))
            }),
        )
    }

    // Subprotocols offered by the test clients; only `echo` is supported by `echo_app`.
    const TEST_ECHO_APP_REQ_SUBPROTO: &str = "echo3, echo";
    // Shared client-side checks against `echo_app`, used by both the HTTP/1.1 and HTTP/2 tests:
    // - the handshake response selected the `echo` subprotocol;
    // - a text message is echoed back verbatim;
    // - a ping is answered with a pong carrying the same payload. The handler ignores pings,
    //   so the pong comes from tungstenite's automatic ping handling.
    async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(
        mut socket: WebSocketStream<S>,
        headers: &http::HeaderMap,
    ) {
        assert_eq!(headers[http::header::SEC_WEBSOCKET_PROTOCOL], "echo");

        let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar"));
        socket.send(input.clone()).await.unwrap();
        let output = socket.next().await.unwrap().unwrap();
        assert_eq!(input, output);

        socket
            .send(tungstenite::Message::Ping(Bytes::from_static(b"ping")))
            .await
            .unwrap();
        let output = socket.next().await.unwrap().unwrap();
        assert_eq!(
            output,
            tungstenite::Message::Pong(Bytes::from_static(b"ping"))
        );
    }
}