axum/extract/
ws.rs

1//! Handle WebSocket connections.
2//!
3//! # Example
4//!
5//! ```
6//! use axum::{
7//!     extract::ws::{WebSocketUpgrade, WebSocket},
8//!     routing::any,
9//!     response::{IntoResponse, Response},
10//!     Router,
11//! };
12//!
13//! let app = Router::new().route("/ws", any(handler));
14//!
15//! async fn handler(ws: WebSocketUpgrade) -> Response {
16//!     ws.on_upgrade(handle_socket)
17//! }
18//!
19//! async fn handle_socket(mut socket: WebSocket) {
20//!     while let Some(msg) = socket.recv().await {
21//!         let msg = if let Ok(msg) = msg {
22//!             msg
23//!         } else {
24//!             // client disconnected
25//!             return;
26//!         };
27//!
28//!         if socket.send(msg).await.is_err() {
29//!             // client disconnected
30//!             return;
31//!         }
32//!     }
33//! }
34//! # let _: Router = app;
35//! ```
36//!
37//! # Passing data and/or state to an `on_upgrade` callback
38//!
39//! ```
40//! use axum::{
41//!     extract::{ws::{WebSocketUpgrade, WebSocket}, State},
42//!     response::Response,
43//!     routing::any,
44//!     Router,
45//! };
46//!
47//! #[derive(Clone)]
48//! struct AppState {
49//!     // ...
50//! }
51//!
52//! async fn handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
53//!     ws.on_upgrade(|socket| handle_socket(socket, state))
54//! }
55//!
56//! async fn handle_socket(socket: WebSocket, state: AppState) {
57//!     // ...
58//! }
59//!
60//! let app = Router::new()
61//!     .route("/ws", any(handler))
62//!     .with_state(AppState { /* ... */ });
63//! # let _: Router = app;
64//! ```
65//!
66//! # Read and write concurrently
67//!
68//! If you need to read and write concurrently from a [`WebSocket`] you can use
69//! [`StreamExt::split`]:
70//!
71//! ```rust,no_run
72//! use axum::{Error, extract::ws::{WebSocket, Message}};
73//! use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}};
74//!
75//! async fn handle_socket(mut socket: WebSocket) {
76//!     let (mut sender, mut receiver) = socket.split();
77//!
78//!     tokio::spawn(write(sender));
79//!     tokio::spawn(read(receiver));
80//! }
81//!
82//! async fn read(receiver: SplitStream<WebSocket>) {
83//!     // ...
84//! }
85//!
86//! async fn write(sender: SplitSink<WebSocket, Message>) {
87//!     // ...
88//! }
89//! ```
90//!
91//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
92
93use self::rejection::*;
94use super::FromRequestParts;
95use crate::{body::Bytes, response::Response, Error};
96use axum_core::body::Body;
97use futures_util::{
98    sink::{Sink, SinkExt},
99    stream::{Stream, StreamExt},
100};
101use http::{
102    header::{self, HeaderMap, HeaderName, HeaderValue},
103    request::Parts,
104    Method, StatusCode, Version,
105};
106use hyper_util::rt::TokioIo;
107use sha1::{Digest, Sha1};
108use std::{
109    borrow::Cow,
110    future::Future,
111    pin::Pin,
112    task::{ready, Context, Poll},
113};
114use tokio_tungstenite::{
115    tungstenite::{
116        self as ts,
117        protocol::{self, WebSocketConfig},
118    },
119    WebSocketStream,
120};
121
/// Extractor for establishing WebSocket connections.
///
/// For HTTP/1.1 requests, this extractor requires the request method to be `GET`;
/// in later versions, `CONNECT` is used instead.
/// To support both, it should be used with [`any`](crate::routing::any).
///
/// See the [module docs](self) for an example.
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
    /// Configuration handed to the WebSocket stream once the connection is upgraded.
    config: WebSocketConfig,
    /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
    protocol: Option<HeaderValue>,
    /// `None` if HTTP/2+ WebSockets are used.
    sec_websocket_key: Option<HeaderValue>,
    /// Pending connection upgrade, taken out of the request extensions.
    on_upgrade: hyper::upgrade::OnUpgrade,
    /// Callback invoked if awaiting the connection upgrade fails.
    on_failed_upgrade: F,
    /// The `Sec-WebSocket-Protocol` header of the request, if one was sent.
    sec_websocket_protocol: Option<HeaderValue>,
}
142
143impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        f.debug_struct("WebSocketUpgrade")
146            .field("config", &self.config)
147            .field("protocol", &self.protocol)
148            .field("sec_websocket_key", &self.sec_websocket_key)
149            .field("sec_websocket_protocol", &self.sec_websocket_protocol)
150            .finish_non_exhaustive()
151    }
152}
153
154impl<F> WebSocketUpgrade<F> {
    /// Read buffer capacity. The default value is 128 KiB.
    ///
    /// The value is stored in the configuration used for the upgraded connection; `self`
    /// is returned so that calls can be chained.
    pub fn read_buffer_size(mut self, size: usize) -> Self {
        self.config.read_buffer_size = size;
        self
    }
160
    /// The target minimum size of the write buffer to reach before writing the data
    /// to the underlying stream.
    ///
    /// The default value is 128 KiB.
    ///
    /// If set to `0`, each message will be eagerly written to the underlying stream.
    /// It is often more optimal to allow them to buffer a little, hence the default value.
    ///
    /// Note: [`flush`](SinkExt::flush) will always fully write the buffer regardless.
    pub fn write_buffer_size(mut self, size: usize) -> Self {
        self.config.write_buffer_size = size;
        self
    }
174
    /// The max size of the write buffer in bytes. Setting this can provide backpressure
    /// in the case the write buffer is filling up due to write errors.
    ///
    /// The default value is unlimited.
    ///
    /// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size)
    /// when writes to the underlying stream are failing. So the **write buffer cannot
    /// fill up if you are not observing write errors, even if not flushing**.
    ///
    /// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size)
    /// and probably a little more depending on error handling strategy.
    pub fn max_write_buffer_size(mut self, max: usize) -> Self {
        self.config.max_write_buffer_size = max;
        self
    }
190
    /// Set the maximum message size (defaults to 64 megabytes).
    ///
    /// The limit is always stored as `Some(max)`; this method cannot remove the limit.
    pub fn max_message_size(mut self, max: usize) -> Self {
        self.config.max_message_size = Some(max);
        self
    }
196
    /// Set the maximum frame size (defaults to 16 megabytes).
    ///
    /// The limit is always stored as `Some(max)`; this method cannot remove the limit.
    pub fn max_frame_size(mut self, max: usize) -> Self {
        self.config.max_frame_size = Some(max);
        self
    }
202
    /// Allow the server to accept unmasked frames (defaults to `false`).
    pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
        self.config.accept_unmasked_frames = accept;
        self
    }
208
209    /// Set the known protocols.
210    ///
211    /// If the protocol name specified by `Sec-WebSocket-Protocol` header
212    /// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
213    /// return the protocol name.
214    ///
215    /// The protocols should be listed in decreasing order of preference: if the client offers
216    /// multiple protocols that the server could support, the server will pick the first one in
217    /// this list.
218    ///
219    /// # Examples
220    ///
221    /// ```
222    /// use axum::{
223    ///     extract::ws::{WebSocketUpgrade, WebSocket},
224    ///     routing::any,
225    ///     response::{IntoResponse, Response},
226    ///     Router,
227    /// };
228    ///
229    /// let app = Router::new().route("/ws", any(handler));
230    ///
231    /// async fn handler(ws: WebSocketUpgrade) -> Response {
232    ///     ws.protocols(["graphql-ws", "graphql-transport-ws"])
233    ///         .on_upgrade(|socket| async {
234    ///             // ...
235    ///         })
236    /// }
237    /// # let _: Router = app;
238    /// ```
239    pub fn protocols<I>(mut self, protocols: I) -> Self
240    where
241        I: IntoIterator,
242        I::Item: Into<Cow<'static, str>>,
243    {
244        if let Some(req_protocols) = self
245            .sec_websocket_protocol
246            .as_ref()
247            .and_then(|p| p.to_str().ok())
248        {
249            self.protocol = protocols
250                .into_iter()
251                // FIXME: This will often allocate a new `String` and so is less efficient than it
252                // could be. But that can't be fixed without breaking changes to the public API.
253                .map(Into::into)
254                .find(|protocol| {
255                    req_protocols
256                        .split(',')
257                        .any(|req_protocol| req_protocol.trim() == protocol)
258                })
259                .map(|protocol| match protocol {
260                    Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
261                    Cow::Borrowed(s) => HeaderValue::from_static(s),
262                });
263        }
264
265        self
266    }
267
    /// Return the selected WebSocket subprotocol, if one has been chosen.
    ///
    /// If [`protocols()`][Self::protocols] has been called and a matching
    /// protocol has been selected, the return value will be `Some` containing
    /// said protocol. Otherwise, it will be `None`.
    ///
    /// The returned value is the one that will be sent back in the
    /// `Sec-WebSocket-Protocol` header of the upgrade response.
    pub fn selected_protocol(&self) -> Option<&HeaderValue> {
        self.protocol.as_ref()
    }
276
277    /// Provide a callback to call if upgrading the connection fails.
278    ///
279    /// The connection upgrade is performed in a background task. If that fails this callback
280    /// will be called.
281    ///
282    /// By default any errors will be silently ignored.
283    ///
284    /// # Example
285    ///
286    /// ```
287    /// use axum::{
288    ///     extract::{WebSocketUpgrade},
289    ///     response::Response,
290    /// };
291    ///
292    /// async fn handler(ws: WebSocketUpgrade) -> Response {
293    ///     ws.on_failed_upgrade(|error| {
294    ///         report_error(error);
295    ///     })
296    ///     .on_upgrade(|socket| async { /* ... */ })
297    /// }
298    /// #
299    /// # fn report_error(_: axum::Error) {}
300    /// ```
301    pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
302    where
303        C: OnFailedUpgrade,
304    {
305        WebSocketUpgrade {
306            config: self.config,
307            protocol: self.protocol,
308            sec_websocket_key: self.sec_websocket_key,
309            on_upgrade: self.on_upgrade,
310            on_failed_upgrade: callback,
311            sec_websocket_protocol: self.sec_websocket_protocol,
312        }
313    }
314
315    /// Finalize upgrading the connection and call the provided callback with
316    /// the stream.
317    #[must_use = "to set up the WebSocket connection, this response must be returned"]
318    pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
319    where
320        C: FnOnce(WebSocket) -> Fut + Send + 'static,
321        Fut: Future<Output = ()> + Send + 'static,
322        F: OnFailedUpgrade,
323    {
324        let on_upgrade = self.on_upgrade;
325        let config = self.config;
326        let on_failed_upgrade = self.on_failed_upgrade;
327
328        let protocol = self.protocol.clone();
329
330        tokio::spawn(async move {
331            let upgraded = match on_upgrade.await {
332                Ok(upgraded) => upgraded,
333                Err(err) => {
334                    on_failed_upgrade.call(Error::new(err));
335                    return;
336                }
337            };
338            let upgraded = TokioIo::new(upgraded);
339
340            let socket =
341                WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
342                    .await;
343            let socket = WebSocket {
344                inner: socket,
345                protocol,
346            };
347            callback(socket).await;
348        });
349
350        let mut response = if let Some(sec_websocket_key) = &self.sec_websocket_key {
351            // If `sec_websocket_key` was `Some`, we are using HTTP/1.1.
352
353            #[allow(clippy::declare_interior_mutable_const)]
354            const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
355            #[allow(clippy::declare_interior_mutable_const)]
356            const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
357
358            Response::builder()
359                .status(StatusCode::SWITCHING_PROTOCOLS)
360                .header(header::CONNECTION, UPGRADE)
361                .header(header::UPGRADE, WEBSOCKET)
362                .header(
363                    header::SEC_WEBSOCKET_ACCEPT,
364                    sign(sec_websocket_key.as_bytes()),
365                )
366                .body(Body::empty())
367                .unwrap()
368        } else {
369            // Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond
370            // with a 2XX with an empty body:
371            // <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
372            Response::new(Body::empty())
373        };
374
375        if let Some(protocol) = self.protocol {
376            response
377                .headers_mut()
378                .insert(header::SEC_WEBSOCKET_PROTOCOL, protocol);
379        }
380
381        response
382    }
383}
384
/// What to do when a connection upgrade fails.
///
/// Implemented for every `FnOnce(Error) + Send + 'static` closure as well as for
/// [`DefaultOnFailedUpgrade`].
///
/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
pub trait OnFailedUpgrade: Send + 'static {
    /// Call the callback, consuming it, with the error produced by the failed upgrade.
    fn call(self, error: Error);
}
392
// Lets plain closures and functions be used as failure callbacks.
impl<F> OnFailedUpgrade for F
where
    F: FnOnce(Error) + Send + 'static,
{
    /// Invokes the closure with the upgrade error.
    fn call(self, error: Error) {
        self(error)
    }
}
401
/// The default `OnFailedUpgrade` used by `WebSocketUpgrade`.
///
/// It simply ignores the error. Use [`WebSocketUpgrade::on_failed_upgrade`] to install a
/// callback that observes it instead.
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultOnFailedUpgrade;
408
impl OnFailedUpgrade for DefaultOnFailedUpgrade {
    /// Discards the error without doing anything.
    #[inline]
    fn call(self, _error: Error) {}
}
413
414impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpgrade>
415where
416    S: Send + Sync,
417{
418    type Rejection = WebSocketUpgradeRejection;
419
420    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
421        let sec_websocket_key = if parts.version <= Version::HTTP_11 {
422            if parts.method != Method::GET {
423                return Err(MethodNotGet.into());
424            }
425
426            if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
427                return Err(InvalidConnectionHeader.into());
428            }
429
430            if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
431                return Err(InvalidUpgradeHeader.into());
432            }
433
434            Some(
435                parts
436                    .headers
437                    .get(header::SEC_WEBSOCKET_KEY)
438                    .ok_or(WebSocketKeyHeaderMissing)?
439                    .clone(),
440            )
441        } else {
442            if parts.method != Method::CONNECT {
443                return Err(MethodNotConnect.into());
444            }
445
446            // if this feature flag is disabled, we won’t be receiving an HTTP/2 request to begin
447            // with.
448            #[cfg(feature = "http2")]
449            if parts
450                .extensions
451                .get::<hyper::ext::Protocol>()
452                .map_or(true, |p| p.as_str() != "websocket")
453            {
454                return Err(InvalidProtocolPseudoheader.into());
455            }
456
457            None
458        };
459
460        if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
461            return Err(InvalidWebSocketVersionHeader.into());
462        }
463
464        let on_upgrade = parts
465            .extensions
466            .remove::<hyper::upgrade::OnUpgrade>()
467            .ok_or(ConnectionNotUpgradable)?;
468
469        let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
470
471        Ok(Self {
472            config: Default::default(),
473            protocol: None,
474            sec_websocket_key,
475            on_upgrade,
476            sec_websocket_protocol,
477            on_failed_upgrade: DefaultOnFailedUpgrade,
478        })
479    }
480}
481
482fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
483    if let Some(header) = headers.get(&key) {
484        header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
485    } else {
486        false
487    }
488}
489
490fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
491    let header = if let Some(header) = headers.get(&key) {
492        header
493    } else {
494        return false;
495    };
496
497    if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
498        header.to_ascii_lowercase().contains(value)
499    } else {
500        false
501    }
502}
503
/// A stream of WebSocket messages.
///
/// See [the module level documentation](self) for more details.
#[derive(Debug)]
pub struct WebSocket {
    /// The underlying WebSocket stream running over the upgraded connection.
    inner: WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
    /// The subprotocol selected during the upgrade, if any.
    protocol: Option<HeaderValue>,
}
512
513impl WebSocket {
514    /// Receive another message.
515    ///
516    /// Returns `None` if the stream has closed.
517    pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
518        self.next().await
519    }
520
521    /// Send a message.
522    pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
523        self.inner
524            .send(msg.into_tungstenite())
525            .await
526            .map_err(Error::new)
527    }
528
529    /// Return the selected WebSocket subprotocol, if one has been chosen.
530    pub fn protocol(&self) -> Option<&HeaderValue> {
531        self.protocol.as_ref()
532    }
533}
534
535impl Stream for WebSocket {
536    type Item = Result<Message, Error>;
537
538    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
539        loop {
540            match ready!(self.inner.poll_next_unpin(cx)) {
541                Some(Ok(msg)) => {
542                    if let Some(msg) = Message::from_tungstenite(msg) {
543                        return Poll::Ready(Some(Ok(msg)));
544                    }
545                }
546                Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))),
547                None => return Poll::Ready(None),
548            }
549        }
550    }
551}
552
553impl Sink<Message> for WebSocket {
554    type Error = Error;
555
556    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
557        Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
558    }
559
560    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
561        Pin::new(&mut self.inner)
562            .start_send(item.into_tungstenite())
563            .map_err(Error::new)
564    }
565
566    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
567        Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
568    }
569
570    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
571        Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
572    }
573}
574
/// UTF-8 wrapper for [`Bytes`].
///
/// A [`Utf8Bytes`] is always guaranteed to contain valid UTF-8.
///
/// This is a newtype around tungstenite's `Utf8Bytes`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Utf8Bytes(ts::Utf8Bytes);
580
impl Utf8Bytes {
    /// Creates from a static str.
    ///
    /// Usable in `const` contexts.
    #[inline]
    pub const fn from_static(str: &'static str) -> Self {
        Self(ts::Utf8Bytes::from_static(str))
    }

    /// Returns as a string slice.
    #[inline]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Unwraps into the underlying tungstenite type.
    fn into_tungstenite(self) -> ts::Utf8Bytes {
        self.0
    }
}
598
impl std::ops::Deref for Utf8Bytes {
    type Target = str;

    /// Borrows the contents as a `str`, so a `&Utf8Bytes` can be used wherever a `&str`
    /// is expected.
    ///
    /// ```
    /// /// Example fn that takes a str slice
    /// fn a(s: &str) {}
    ///
    /// let data = axum::extract::ws::Utf8Bytes::from_static("foo123");
    ///
    /// // auto-deref as arg
    /// a(&data);
    ///
    /// // deref to str methods
    /// assert_eq!(data.len(), 6);
    /// ```
    #[inline]
    fn deref(&self) -> &Self::Target {
        self.as_str()
    }
}
619
impl std::fmt::Display for Utf8Bytes {
    /// Writes the contained string to the formatter as-is via `write_str`.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
626
627impl TryFrom<Bytes> for Utf8Bytes {
628    type Error = std::str::Utf8Error;
629
630    #[inline]
631    fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
632        Ok(Self(bytes.try_into()?))
633    }
634}
635
636impl TryFrom<Vec<u8>> for Utf8Bytes {
637    type Error = std::str::Utf8Error;
638
639    #[inline]
640    fn try_from(v: Vec<u8>) -> Result<Self, Self::Error> {
641        Ok(Self(v.try_into()?))
642    }
643}
644
645impl From<String> for Utf8Bytes {
646    #[inline]
647    fn from(s: String) -> Self {
648        Self(s.into())
649    }
650}
651
652impl From<&str> for Utf8Bytes {
653    #[inline]
654    fn from(s: &str) -> Self {
655        Self(s.into())
656    }
657}
658
659impl From<&String> for Utf8Bytes {
660    #[inline]
661    fn from(s: &String) -> Self {
662        Self(s.into())
663    }
664}
665
666impl From<Utf8Bytes> for Bytes {
667    #[inline]
668    fn from(Utf8Bytes(bytes): Utf8Bytes) -> Self {
669        bytes.into()
670    }
671}
672
673impl<T> PartialEq<T> for Utf8Bytes
674where
675    for<'a> &'a str: PartialEq<T>,
676{
677    /// ```
678    /// let payload = axum::extract::ws::Utf8Bytes::from_static("foo123");
679    /// assert_eq!(payload, "foo123");
680    /// assert_eq!(payload, "foo123".to_string());
681    /// assert_eq!(payload, &"foo123".to_string());
682    /// assert_eq!(payload, std::borrow::Cow::from("foo123"));
683    /// ```
684    #[inline]
685    fn eq(&self, other: &T) -> bool {
686        self.as_str() == *other
687    }
688}
689
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
///
/// This is a plain `u16`.
pub type CloseCode = u16;
692
/// A struct representing the close command.
///
/// Carried by [`Message::Close`] when the close message includes a code and reason.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame {
    /// The reason as a code.
    pub code: CloseCode,
    /// The reason as text string.
    pub reason: Utf8Bytes,
}
701
702/// A WebSocket message.
703//
704// This code comes from https://github.com/snapview/tungstenite-rs/blob/master/src/protocol/message.rs and is under following license:
705// Copyright (c) 2017 Alexey Galakhov
706// Copyright (c) 2016 Jason Housley
707//
708// Permission is hereby granted, free of charge, to any person obtaining a copy
709// of this software and associated documentation files (the "Software"), to deal
710// in the Software without restriction, including without limitation the rights
711// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
712// copies of the Software, and to permit persons to whom the Software is
713// furnished to do so, subject to the following conditions:
714//
715// The above copyright notice and this permission notice shall be included in
716// all copies or substantial portions of the Software.
717//
718// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
719// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
720// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
721// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
722// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
723// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
724// THE SOFTWARE.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {
    /// A text WebSocket message
    Text(Utf8Bytes),
    /// A binary WebSocket message
    Binary(Bytes),
    /// A ping message with the specified payload
    ///
    /// The payload here must have a length of at most 125 bytes.
    ///
    /// Ping messages will be automatically responded to by the server, so you do not have to worry
    /// about dealing with them yourself.
    Ping(Bytes),
    /// A pong message with the specified payload
    ///
    /// The payload here must have a length of at most 125 bytes.
    ///
    /// Pong messages will be automatically sent to the client if a ping message is received, so
    /// you do not have to worry about constructing them yourself unless you want to implement a
    /// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
    Pong(Bytes),
    /// A close message with the optional close frame.
    ///
    /// You may "uncleanly" close a WebSocket connection at any time
    /// by simply dropping the [`WebSocket`].
    /// However, you may also use the graceful closing protocol, in which
    /// 1. peer A sends a close frame, and does not send any further messages;
    /// 2. peer B responds with a close frame, and does not send any further messages;
    /// 3. peer A processes the remaining messages sent by peer B, before finally
    /// 4. both peers close the connection.
    ///
    /// After sending a close frame,
    /// you may still read messages,
    /// but attempts to send another message will error.
    /// After receiving a close frame,
    /// axum will automatically respond with a close frame if necessary
    /// (you do not have to deal with this yourself).
    /// Since no further messages will be received,
    /// you may either do nothing
    /// or explicitly drop the connection.
    Close(Option<CloseFrame>),
}
767
768impl Message {
769    fn into_tungstenite(self) -> ts::Message {
770        match self {
771            Self::Text(text) => ts::Message::Text(text.into_tungstenite()),
772            Self::Binary(binary) => ts::Message::Binary(binary),
773            Self::Ping(ping) => ts::Message::Ping(ping),
774            Self::Pong(pong) => ts::Message::Pong(pong),
775            Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
776                code: ts::protocol::frame::coding::CloseCode::from(close.code),
777                reason: close.reason.into_tungstenite(),
778            })),
779            Self::Close(None) => ts::Message::Close(None),
780        }
781    }
782
783    fn from_tungstenite(message: ts::Message) -> Option<Self> {
784        match message {
785            ts::Message::Text(text) => Some(Self::Text(Utf8Bytes(text))),
786            ts::Message::Binary(binary) => Some(Self::Binary(binary)),
787            ts::Message::Ping(ping) => Some(Self::Ping(ping)),
788            ts::Message::Pong(pong) => Some(Self::Pong(pong)),
789            ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
790                code: close.code.into(),
791                reason: Utf8Bytes(close.reason),
792            }))),
793            ts::Message::Close(None) => Some(Self::Close(None)),
794            // we can ignore `Frame` frames as recommended by the tungstenite maintainers
795            // https://github.com/snapview/tungstenite-rs/issues/268
796            ts::Message::Frame(_) => None,
797        }
798    }
799
800    /// Consume the WebSocket and return it as binary data.
801    pub fn into_data(self) -> Bytes {
802        match self {
803            Self::Text(string) => Bytes::from(string),
804            Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
805            Self::Close(None) => Bytes::new(),
806            Self::Close(Some(frame)) => Bytes::from(frame.reason),
807        }
808    }
809
810    /// Attempt to consume the WebSocket message and convert it to a Utf8Bytes.
811    pub fn into_text(self) -> Result<Utf8Bytes, Error> {
812        match self {
813            Self::Text(string) => Ok(string),
814            Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => {
815                Ok(Utf8Bytes::try_from(data).map_err(Error::new)?)
816            }
817            Self::Close(None) => Ok(Utf8Bytes::default()),
818            Self::Close(Some(frame)) => Ok(frame.reason),
819        }
820    }
821
822    /// Attempt to get a &str from the WebSocket message,
823    /// this will try to convert binary data to utf8.
824    pub fn to_text(&self) -> Result<&str, Error> {
825        match *self {
826            Self::Text(ref string) => Ok(string.as_str()),
827            Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
828                Ok(std::str::from_utf8(data).map_err(Error::new)?)
829            }
830            Self::Close(None) => Ok(""),
831            Self::Close(Some(ref frame)) => Ok(&frame.reason),
832        }
833    }
834
835    /// Create a new text WebSocket message from a stringable.
836    pub fn text<S>(string: S) -> Message
837    where
838        S: Into<Utf8Bytes>,
839    {
840        Message::Text(string.into())
841    }
842
843    /// Create a new binary WebSocket message by converting to `Bytes`.
844    pub fn binary<B>(bin: B) -> Message
845    where
846        B: Into<Bytes>,
847    {
848        Message::Binary(bin.into())
849    }
850}
851
852impl From<String> for Message {
853    fn from(string: String) -> Self {
854        Message::Text(string.into())
855    }
856}
857
858impl<'s> From<&'s str> for Message {
859    fn from(string: &'s str) -> Self {
860        Message::Text(string.into())
861    }
862}
863
864impl<'b> From<&'b [u8]> for Message {
865    fn from(data: &'b [u8]) -> Self {
866        Message::Binary(Bytes::copy_from_slice(data))
867    }
868}
869
870impl From<Bytes> for Message {
871    fn from(data: Bytes) -> Self {
872        Message::Binary(data)
873    }
874}
875
876impl From<Vec<u8>> for Message {
877    fn from(data: Vec<u8>) -> Self {
878        Message::Binary(data.into())
879    }
880}
881
882impl From<Message> for Vec<u8> {
883    fn from(msg: Message) -> Self {
884        msg.into_data().to_vec()
885    }
886}
887
888fn sign(key: &[u8]) -> HeaderValue {
889    use base64::engine::Engine as _;
890
891    let mut sha1 = Sha1::default();
892    sha1.update(key);
893    sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
894    let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
895    HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
896}
897
pub mod rejection {
    //! WebSocket specific rejections.
    //!
    //! Every rejection in this module is declared through `define_rejection!` with a status
    //! (`#[status = ...]`) and a message (`#[body = ...]`). [`WebSocketUpgradeRejection`] is the
    //! composite enum that has one variant per rejection.

    // The macros are exported by `axum_core` under `__`-prefixed names; alias them locally.
    use axum_core::__composite_rejection as composite_rejection;
    use axum_core::__define_rejection as define_rejection;

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `GET`"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Signals that the request method was required to be `GET` but was not.
        pub struct MethodNotGet;
    }

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `CONNECT`"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Signals that the request method was required to be `CONNECT` but was not.
        pub struct MethodNotConnect;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "Connection header did not include 'upgrade'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Signals that the `Connection` header did not include `upgrade`.
        pub struct InvalidConnectionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Upgrade` header did not include 'websocket'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Signals that the `Upgrade` header did not include `websocket`.
        pub struct InvalidUpgradeHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`:protocol` pseudo-header did not include 'websocket'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Signals that the `:protocol` pseudo-header did not include `websocket`.
        pub struct InvalidProtocolPseudoheader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Version` header did not include '13'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Signals that the `Sec-WebSocket-Version` header did not include `13`.
        pub struct InvalidWebSocketVersionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Key` header missing"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Signals that the request carried no `Sec-WebSocket-Key` header.
        pub struct WebSocketKeyHeaderMissing;
    }

    define_rejection! {
        #[status = UPGRADE_REQUIRED]
        #[body = "WebSocket request couldn't be upgraded since no upgrade state was present"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// This rejection is returned if the connection cannot be upgraded, for example if the
        /// request is HTTP/1.0.
        ///
        /// See [MDN] for more details about connection upgrades.
        ///
        /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Upgrade
        pub struct ConnectionNotUpgradable;
    }

    composite_rejection! {
        /// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
        /// extractor can fail.
        pub enum WebSocketUpgradeRejection {
            MethodNotGet,
            MethodNotConnect,
            InvalidConnectionHeader,
            InvalidUpgradeHeader,
            InvalidProtocolPseudoheader,
            InvalidWebSocketVersionHeader,
            WebSocketKeyHeaderMissing,
            ConnectionNotUpgradable,
        }
    }
}
984
pub mod close_code {
    //! Constants for [`CloseCode`]s.
    //!
    //! Every constant is a plain `u16` status code; the values defined here lie in the range
    //! `1000..=1013` (there is no constant for `1004`).
    //!
    //! [`CloseCode`]: super::CloseCode

    /// Indicates a normal closure (`1000`), meaning that the purpose for which the connection was
    /// established has been fulfilled.
    pub const NORMAL: u16 = 1000;

    /// Indicates that an endpoint is "going away" (`1001`), such as a server going down or a
    /// browser having navigated away from a page.
    pub const AWAY: u16 = 1001;

    /// Indicates that an endpoint is terminating the connection due to a protocol error (`1002`).
    pub const PROTOCOL: u16 = 1002;

    /// Indicates that an endpoint is terminating the connection because it has received a type of
    /// data that it cannot accept (`1003`).
    ///
    /// For example, an endpoint MAY send this if it understands only text data, but receives a
    /// binary message.
    pub const UNSUPPORTED: u16 = 1003;

    // NOTE(review): RFC 6455 §7.4.1 appears to reserve 1005 and 1006 for local reporting only
    // (not to be sent in a close frame) — confirm before passing them to `send`.
    /// Indicates that no status code was included in a closing frame (`1005`).
    pub const STATUS: u16 = 1005;

    /// Indicates an abnormal closure (`1006`).
    pub const ABNORMAL: u16 = 1006;

    /// Indicates that an endpoint is terminating the connection because it has received data
    /// within a message that was not consistent with the type of the message (`1007`).
    ///
    /// For example, an endpoint received non-UTF-8 RFC3629 data within a text message.
    pub const INVALID: u16 = 1007;

    /// Indicates that an endpoint is terminating the connection because it has received a message
    /// that violates its policy (`1008`).
    ///
    /// This is a generic status code that can be returned when there is
    /// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to
    /// hide specific details about the policy.
    pub const POLICY: u16 = 1008;

    /// Indicates that an endpoint is terminating the connection because it has received a message
    /// that is too big for it to process (`1009`).
    pub const SIZE: u16 = 1009;

    /// Indicates that an endpoint (client) is terminating the connection because the server
    /// did not respond to extension negotiation correctly (`1010`).
    ///
    /// Specifically, the client has expected the server to negotiate one or more extension(s),
    /// but the server didn't return them in the response message of the WebSocket handshake.
    /// The list of extensions that are needed should be given as the reason for closing.
    /// Note that this status code is not used by the server,
    /// because it can fail the WebSocket handshake instead.
    pub const EXTENSION: u16 = 1010;

    /// Indicates that a server is terminating the connection because it encountered an unexpected
    /// condition that prevented it from fulfilling the request (`1011`).
    pub const ERROR: u16 = 1011;

    /// Indicates that the server is restarting (`1012`).
    pub const RESTART: u16 = 1012;

    /// Indicates that the server is overloaded and the client should either connect to a different
    /// IP (when multiple targets exist), or reconnect to the same IP when a user has performed an
    /// action (`1013`).
    pub const AGAIN: u16 = 1013;
}
1053
// Tests for the WebSocket extractor: one rejection test, two compile-only checks and two echo
// round-trip tests (plain client and, behind the `http2` feature, an extended-CONNECT client).
#[cfg(test)]
mod tests {
    use std::future::ready;

    use super::*;
    use crate::{routing::any, test_helpers::spawn_service, Router};
    use http::{Request, Version};
    use http_body_util::BodyExt as _;
    use hyper_util::rt::TokioExecutor;
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;
    use tokio_tungstenite::tungstenite;
    use tower::ServiceExt;

    /// An HTTP/1.0 request with otherwise valid handshake headers must be rejected with
    /// `ConnectionNotUpgradable`.
    #[crate::test]
    async fn rejects_http_1_0_requests() {
        // Extracting `Result<_, _>` lets the handler inspect the rejection instead of having it
        // turned into a response.
        let svc = any(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
            let rejection = ws.unwrap_err();
            assert!(matches!(
                rejection,
                WebSocketUpgradeRejection::ConnectionNotUpgradable(_)
            ));
            std::future::ready(())
        });

        let req = Request::builder()
            .version(Version::HTTP_10)
            .method(Method::GET)
            .header("upgrade", "websocket")
            .header("connection", "Upgrade")
            .header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==")
            .header("sec-websocket-version", "13")
            .body(Body::empty())
            .unwrap();

        let res = svc.oneshot(req).await.unwrap();

        // The handler returns `()`, so the response is expected to be 200 OK; the check that
        // matters is the `matches!` assertion inside the handler.
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// Compile-only check (never called, hence `dead_code`): a handler that calls `on_upgrade`
    /// without `on_failed_upgrade` can be used as a route.
    #[allow(dead_code)]
    fn default_on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", any(handler));
    }

    /// Compile-only check (never called, hence `dead_code`): a handler that installs a custom
    /// `on_failed_upgrade` callback before `on_upgrade` can be used as a route.
    #[allow(dead_code)]
    fn on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_failed_upgrade(|_error: Error| println!("oops!"))
                .on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", any(handler));
    }

    /// Round-trip against `echo_app` using the `tokio_tungstenite` client over a `ws://` URI.
    #[crate::test]
    async fn integration_test() {
        let addr = spawn_service(echo_app());
        let uri = format!("ws://{addr}/echo").try_into().unwrap();
        let req = tungstenite::client::ClientRequestBuilder::new(uri)
            .with_sub_protocol(TEST_ECHO_APP_REQ_SUBPROTO);
        let (socket, response) = tokio_tungstenite::connect_async(req).await.unwrap();
        test_echo_app(socket, response.headers()).await;
    }

    /// Round-trip against `echo_app` over HTTP/2, using a `CONNECT` request carrying the
    /// `websocket` protocol extension.
    #[crate::test]
    #[cfg(feature = "http2")]
    async fn http2() {
        let addr = spawn_service(echo_app());
        let io = TokioIo::new(TcpStream::connect(addr).await.unwrap());
        let (mut send_request, conn) =
            hyper::client::conn::http2::Builder::new(TokioExecutor::new())
                .handshake(io)
                .await
                .unwrap();

        // Wait a little for the SETTINGS frame to go through…
        for _ in 0..10 {
            tokio::task::yield_now().await;
        }
        assert!(conn.is_extended_connect_protocol_enabled());
        // Drive the connection in the background so requests sent below can make progress.
        tokio::spawn(async {
            conn.await.unwrap();
        });

        let req = Request::builder()
            .method(Method::CONNECT)
            .extension(hyper::ext::Protocol::from_static("websocket"))
            .uri("/echo")
            .header("sec-websocket-version", "13")
            .header("sec-websocket-protocol", TEST_ECHO_APP_REQ_SUBPROTO)
            .header("Host", "server.example.com")
            .body(Body::empty())
            .unwrap();

        let mut response = send_request.send_request(req).await.unwrap();
        let status = response.status();
        // On any non-200 status, surface the response body in the panic message to ease debugging.
        if status != 200 {
            let body = response.into_body().collect().await.unwrap().to_bytes();
            let body = std::str::from_utf8(&body).unwrap();
            panic!("response status was {status}: {body}");
        }
        // Turn the upgraded stream into a client-role WebSocket by hand; no handshake is
        // performed by `from_raw_socket`.
        let upgraded = hyper::upgrade::on(&mut response).await.unwrap();
        let upgraded = TokioIo::new(upgraded);
        let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
        test_echo_app(socket, response.headers()).await;
    }

    /// Builds a router with a single `/echo` route that echoes text, binary and close messages
    /// back to the client and asserts that the `echo` sub-protocol was selected.
    fn echo_app() -> Router {
        async fn handle_socket(mut socket: WebSocket) {
            assert_eq!(socket.protocol().unwrap(), "echo");
            // Stops on the first receive error or when the stream ends.
            while let Some(Ok(msg)) = socket.recv().await {
                match msg {
                    Message::Text(_) | Message::Binary(_) | Message::Close(_) => {
                        if socket.send(msg).await.is_err() {
                            break;
                        }
                    }
                    Message::Ping(_) | Message::Pong(_) => {
                        // tungstenite will respond to pings automatically
                    }
                }
            }
        }

        Router::new().route(
            "/echo",
            any(|ws: WebSocketUpgrade| {
                // The server supports `echo2` and `echo`; the clients in these tests request
                // `echo3, echo`, so `echo` is the only protocol both sides share.
                let ws = ws.protocols(["echo2", "echo"]);
                assert_eq!(ws.selected_protocol().unwrap(), "echo");
                ready(ws.on_upgrade(handle_socket))
            }),
        )
    }

    // Sub-protocols requested by the test clients, as a comma-separated header value.
    const TEST_ECHO_APP_REQ_SUBPROTO: &str = "echo3, echo";

    /// Client-side checks shared by the echo tests: the response must carry the selected `echo`
    /// sub-protocol, a text message must be echoed back unchanged, and a ping must be answered
    /// with a pong carrying the same payload.
    async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(
        mut socket: WebSocketStream<S>,
        headers: &http::HeaderMap,
    ) {
        assert_eq!(headers[http::header::SEC_WEBSOCKET_PROTOCOL], "echo");

        let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar"));
        socket.send(input.clone()).await.unwrap();
        let output = socket.next().await.unwrap().unwrap();
        assert_eq!(input, output);

        socket
            .send(tungstenite::Message::Ping(Bytes::from_static(b"ping")))
            .await
            .unwrap();
        let output = socket.next().await.unwrap().unwrap();
        assert_eq!(
            output,
            tungstenite::Message::Pong(Bytes::from_static(b"ping"))
        );
    }
}