Skip to main content

axum/extract/
ws.rs

1//! Handle WebSocket connections.
2//!
3//! # Example
4//!
5//! ```
6//! use axum::{
7//!     extract::ws::{WebSocketUpgrade, WebSocket},
8//!     routing::any,
9//!     response::{IntoResponse, Response},
10//!     Router,
11//! };
12//!
13//! let app = Router::new().route("/ws", any(handler));
14//!
15//! async fn handler(ws: WebSocketUpgrade) -> Response {
16//!     ws.on_upgrade(handle_socket)
17//! }
18//!
19//! async fn handle_socket(mut socket: WebSocket) {
20//!     while let Some(msg) = socket.recv().await {
21//!         let msg = if let Ok(msg) = msg {
22//!             msg
23//!         } else {
24//!             // client disconnected
25//!             return;
26//!         };
27//!
28//!         if socket.send(msg).await.is_err() {
29//!             // client disconnected
30//!             return;
31//!         }
32//!     }
33//! }
34//! # let _: Router = app;
35//! ```
36//!
37//! # Passing data and/or state to an `on_upgrade` callback
38//!
39//! ```
40//! use axum::{
41//!     extract::{ws::{WebSocketUpgrade, WebSocket}, State},
42//!     response::Response,
43//!     routing::any,
44//!     Router,
45//! };
46//!
47//! #[derive(Clone)]
48//! struct AppState {
49//!     // ...
50//! }
51//!
52//! async fn handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
53//!     ws.on_upgrade(|socket| handle_socket(socket, state))
54//! }
55//!
56//! async fn handle_socket(socket: WebSocket, state: AppState) {
57//!     // ...
58//! }
59//!
60//! let app = Router::new()
61//!     .route("/ws", any(handler))
62//!     .with_state(AppState { /* ... */ });
63//! # let _: Router = app;
64//! ```
65//!
66//! # Read and write concurrently
67//!
68//! If you need to read and write concurrently from a [`WebSocket`] you can use
69//! [`StreamExt::split`]:
70//!
71//! ```rust,no_run
72//! use axum::{Error, extract::ws::{WebSocket, Message}};
73//! use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}};
74//!
//! async fn handle_socket(socket: WebSocket) {
//!     let (sender, receiver) = socket.split();
77//!
78//!     tokio::spawn(write(sender));
79//!     tokio::spawn(read(receiver));
80//! }
81//!
82//! async fn read(receiver: SplitStream<WebSocket>) {
83//!     // ...
84//! }
85//!
86//! async fn write(sender: SplitSink<WebSocket, Message>) {
87//!     // ...
88//! }
89//! ```
90//!
91//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
92
93use self::rejection::*;
94use super::FromRequestParts;
95use crate::{body::Bytes, response::Response, Error};
96use axum_core::body::Body;
97use futures_util::{
98    sink::{Sink, SinkExt},
99    stream::{FusedStream, Stream, StreamExt},
100};
101use http::{
102    header::{self, HeaderMap, HeaderName, HeaderValue},
103    request::Parts,
104    Method, StatusCode, Version,
105};
106use hyper_util::rt::TokioIo;
107use sha1::{Digest, Sha1};
108use std::{
109    borrow::Cow,
110    collections::BTreeSet,
111    future::Future,
112    pin::Pin,
113    str,
114    task::{ready, Context, Poll},
115};
116use tokio_tungstenite::{
117    tungstenite::{
118        self as ts,
119        protocol::{self, WebSocketConfig},
120    },
121    WebSocketStream,
122};
123
124/// Extractor for establishing WebSocket connections.
125///
126/// For HTTP/1.1 requests, this extractor requires the request method to be `GET`;
127/// in later versions, `CONNECT` is used instead.
128/// To support both, it should be used with [`any`](crate::routing::any).
129///
130/// See the [module docs](self) for an example.
131///
132/// [`MethodFilter`]: crate::routing::MethodFilter
133#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
134#[must_use]
135pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
136    config: WebSocketConfig,
137    /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
138    protocol: Option<HeaderValue>,
139    /// `None` if HTTP/2+ WebSockets are used.
140    sec_websocket_key: Option<HeaderValue>,
141    on_upgrade: hyper::upgrade::OnUpgrade,
142    on_failed_upgrade: F,
143    sec_websocket_protocol: BTreeSet<HeaderValue>,
144}
145
146impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
147    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148        f.debug_struct("WebSocketUpgrade")
149            .field("config", &self.config)
150            .field("protocol", &self.protocol)
151            .field("sec_websocket_key", &self.sec_websocket_key)
152            .field("sec_websocket_protocol", &self.sec_websocket_protocol)
153            .finish_non_exhaustive()
154    }
155}
156
157impl<F> WebSocketUpgrade<F> {
158    /// Read buffer capacity. The default value is 128KiB
159    pub fn read_buffer_size(mut self, size: usize) -> Self {
160        self.config.read_buffer_size = size;
161        self
162    }
163
164    /// The target minimum size of the write buffer to reach before writing the data
165    /// to the underlying stream.
166    ///
167    /// The default value is 128 KiB.
168    ///
169    /// If set to `0` each message will be eagerly written to the underlying stream.
170    /// It is often more optimal to allow them to buffer a little, hence the default value.
171    ///
172    /// Note: [`flush`](SinkExt::flush) will always fully write the buffer regardless.
173    pub fn write_buffer_size(mut self, size: usize) -> Self {
174        self.config.write_buffer_size = size;
175        self
176    }
177
178    /// The max size of the write buffer in bytes. Setting this can provide backpressure
179    /// in the case the write buffer is filling up due to write errors.
180    ///
181    /// The default value is unlimited.
182    ///
183    /// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size)
184    /// when writes to the underlying stream are failing. So the **write buffer can not
185    /// fill up if you are not observing write errors even if not flushing**.
186    ///
187    /// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size)
188    /// and probably a little more depending on error handling strategy.
189    pub fn max_write_buffer_size(mut self, max: usize) -> Self {
190        self.config.max_write_buffer_size = max;
191        self
192    }
193
194    /// Set the maximum message size (defaults to 64 megabytes)
195    pub fn max_message_size(mut self, max: usize) -> Self {
196        self.config.max_message_size = Some(max);
197        self
198    }
199
200    /// Set the maximum frame size (defaults to 16 megabytes)
201    pub fn max_frame_size(mut self, max: usize) -> Self {
202        self.config.max_frame_size = Some(max);
203        self
204    }
205
206    /// Allow server to accept unmasked frames (defaults to false)
207    pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
208        self.config.accept_unmasked_frames = accept;
209        self
210    }
211
212    /// Set the known protocols.
213    ///
214    /// If the protocol name specified by `Sec-WebSocket-Protocol` header
215    /// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
216    /// return the protocol name.
217    ///
218    /// The protocols should be listed in decreasing order of preference: if the client offers
219    /// multiple protocols that the server could support, the server will pick the first one in
220    /// this list.
221    ///
222    /// # Examples
223    ///
224    /// ```
225    /// use axum::{
226    ///     extract::ws::{WebSocketUpgrade, WebSocket},
227    ///     routing::any,
228    ///     response::{IntoResponse, Response},
229    ///     Router,
230    /// };
231    ///
232    /// let app = Router::new().route("/ws", any(handler));
233    ///
234    /// async fn handler(ws: WebSocketUpgrade) -> Response {
235    ///     ws.protocols(["graphql-ws", "graphql-transport-ws"])
236    ///         .on_upgrade(|socket| async {
237    ///             // ...
238    ///         })
239    /// }
240    /// # let _: Router = app;
241    /// ```
242    pub fn protocols<I>(mut self, protocols: I) -> Self
243    where
244        I: IntoIterator,
245        I::Item: Into<Cow<'static, str>>,
246    {
247        self.protocol = protocols
248            .into_iter()
249            .map(Into::into)
250            .find(|proto| {
251                // FIXME: When https://github.com/hyperium/http/pull/814
252                //        is merged + released, we can look use
253                //        `contains(proto.as_bytes())` without converting
254                //        to `HeaderValue` first.
255                let Ok(proto) = HeaderValue::from_str(proto) else {
256                    return false;
257                };
258                self.sec_websocket_protocol.contains(&proto)
259            })
260            .map(|protocol| match protocol {
261                Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
262                Cow::Borrowed(s) => HeaderValue::from_static(s),
263            });
264
265        self
266    }
267
268    /// Return the WebSocket subprotocols requested by the client.
269    ///
270    /// # Examples
271    ///
272    /// If the client sends the following HTTP header in the WebSocket upgrade request:
273    ///
274    /// ```txt
275    /// Sec-WebSocket-Protocol: soap, wamp
276    /// ```
277    ///
278    /// this method returns an iterator yielding `"soap"` and `"wamp"`.
279    pub fn requested_protocols(&self) -> impl Iterator<Item = &HeaderValue> {
280        self.sec_websocket_protocol.iter()
281    }
282
283    /// Set the chosen WebSocket subprotocol.
284    ///
285    /// Another method, [`protocols()`][Self::protocols], also sets the chosen WebSocket
286    /// subprotocol. If both methods are called, only the latter call takes effect.
287    ///
288    /// # Notes
289    ///
290    /// - The chosen protocol is echoed back in the WebSocket upgrade
291    ///   response as required by RFC 6455. Some browsers may reject a
292    ///   value that was not present in the client's request.
293    pub fn set_selected_protocol(&mut self, protocol: HeaderValue) {
294        self.protocol = Some(protocol);
295    }
296
297    /// Return the selected WebSocket subprotocol, if one has been chosen.
298    ///
299    /// If [`protocols()`][Self::protocols] selects a matching protocol, or
300    /// [`set_selected_protocol()`][Self::set_selected_protocol] has been called, the return
301    /// value will be `Some` containing the selected protocol. Otherwise, it will be `None`.
302    pub fn selected_protocol(&self) -> Option<&HeaderValue> {
303        self.protocol.as_ref()
304    }
305
306    /// Provide a callback to call if upgrading the connection fails.
307    ///
308    /// The connection upgrade is performed in a background task. If that fails this callback
309    /// will be called.
310    ///
311    /// By default any errors will be silently ignored.
312    ///
313    /// # Example
314    ///
315    /// ```
316    /// use axum::{
317    ///     extract::{WebSocketUpgrade},
318    ///     response::Response,
319    /// };
320    ///
321    /// async fn handler(ws: WebSocketUpgrade) -> Response {
322    ///     ws.on_failed_upgrade(|error| {
323    ///         report_error(error);
324    ///     })
325    ///     .on_upgrade(|socket| async { /* ... */ })
326    /// }
327    /// #
328    /// # fn report_error(_: axum::Error) {}
329    /// ```
330    pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
331    where
332        C: OnFailedUpgrade,
333    {
334        WebSocketUpgrade {
335            config: self.config,
336            protocol: self.protocol,
337            sec_websocket_key: self.sec_websocket_key,
338            on_upgrade: self.on_upgrade,
339            on_failed_upgrade: callback,
340            sec_websocket_protocol: self.sec_websocket_protocol,
341        }
342    }
343
344    /// Finalize upgrading the connection and call the provided callback with
345    /// the stream.
346    #[must_use = "to set up the WebSocket connection, this response must be returned"]
347    pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
348    where
349        C: FnOnce(WebSocket) -> Fut + Send + 'static,
350        Fut: Future<Output = ()> + Send + 'static,
351        F: OnFailedUpgrade,
352    {
353        let on_upgrade = self.on_upgrade;
354        let config = self.config;
355        let on_failed_upgrade = self.on_failed_upgrade;
356
357        let protocol = self.protocol.clone();
358
359        tokio::spawn(async move {
360            let upgraded = match on_upgrade.await {
361                Ok(upgraded) => upgraded,
362                Err(err) => {
363                    on_failed_upgrade.call(Error::new(err));
364                    return;
365                }
366            };
367            let upgraded = TokioIo::new(upgraded);
368
369            let socket =
370                WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
371                    .await;
372            let socket = WebSocket {
373                inner: socket,
374                protocol,
375            };
376            callback(socket).await;
377        });
378
379        let mut response = if let Some(sec_websocket_key) = &self.sec_websocket_key {
380            // If `sec_websocket_key` was `Some`, we are using HTTP/1.1.
381
382            #[allow(clippy::declare_interior_mutable_const)]
383            const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
384            #[allow(clippy::declare_interior_mutable_const)]
385            const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
386
387            Response::builder()
388                .status(StatusCode::SWITCHING_PROTOCOLS)
389                .header(header::CONNECTION, UPGRADE)
390                .header(header::UPGRADE, WEBSOCKET)
391                .header(
392                    header::SEC_WEBSOCKET_ACCEPT,
393                    sign(sec_websocket_key.as_bytes()),
394                )
395                .body(Body::empty())
396                .unwrap()
397        } else {
398            // Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond
399            // with a 2XX with an empty body:
400            // <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
401            Response::new(Body::empty())
402        };
403
404        if let Some(protocol) = self.protocol {
405            response
406                .headers_mut()
407                .insert(header::SEC_WEBSOCKET_PROTOCOL, protocol);
408        }
409
410        response
411    }
412}
413
/// What to do when a connection upgrade fails.
///
/// The value is moved into the background task spawned by
/// [`WebSocketUpgrade::on_upgrade`], hence the `Send + 'static` bounds. It is consumed
/// at most once, only when awaiting the upgrade returns an error.
///
/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
pub trait OnFailedUpgrade: Send + 'static {
    /// Call the callback, handing it the upgrade error.
    fn call(self, error: Error);
}
421
422impl<F> OnFailedUpgrade for F
423where
424    F: FnOnce(Error) + Send + 'static,
425{
426    fn call(self, error: Error) {
427        self(error)
428    }
429}
430
431/// The default `OnFailedUpgrade` used by `WebSocketUpgrade`.
432///
433/// It simply ignores the error.
434#[non_exhaustive]
435#[derive(Debug)]
436pub struct DefaultOnFailedUpgrade;
437
438impl OnFailedUpgrade for DefaultOnFailedUpgrade {
439    #[inline]
440    fn call(self, _error: Error) {}
441}
442
443impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpgrade>
444where
445    S: Send + Sync,
446{
447    type Rejection = WebSocketUpgradeRejection;
448
449    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
450        let sec_websocket_key = if parts.version <= Version::HTTP_11 {
451            if parts.method != Method::GET {
452                return Err(MethodNotGet.into());
453            }
454
455            if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
456                return Err(InvalidConnectionHeader.into());
457            }
458
459            if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
460                return Err(InvalidUpgradeHeader.into());
461            }
462
463            Some(
464                parts
465                    .headers
466                    .get(header::SEC_WEBSOCKET_KEY)
467                    .ok_or(WebSocketKeyHeaderMissing)?
468                    .clone(),
469            )
470        } else {
471            if parts.method != Method::CONNECT {
472                return Err(MethodNotConnect.into());
473            }
474
475            // if this feature flag is disabled, we won’t be receiving an HTTP/2 request to begin
476            // with.
477            #[cfg(feature = "http2")]
478            if parts
479                .extensions
480                .get::<hyper::ext::Protocol>()
481                .map_or(true, |p| p.as_str() != "websocket")
482            {
483                return Err(InvalidProtocolPseudoheader.into());
484            }
485
486            None
487        };
488
489        if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
490            return Err(InvalidWebSocketVersionHeader.into());
491        }
492
493        let on_upgrade = parts
494            .extensions
495            .remove::<hyper::upgrade::OnUpgrade>()
496            .ok_or(ConnectionNotUpgradable)?;
497
498        let sec_websocket_protocol = parts
499            .headers
500            .get_all(header::SEC_WEBSOCKET_PROTOCOL)
501            .iter()
502            .flat_map(|val| val.as_bytes().split(|&b| b == b','))
503            .map(|proto| {
504                HeaderValue::from_bytes(proto.trim_ascii())
505                    .expect("substring of HeaderValue is valid HeaderValue")
506            })
507            .collect();
508
509        Ok(Self {
510            config: Default::default(),
511            protocol: None,
512            sec_websocket_key,
513            on_upgrade,
514            sec_websocket_protocol,
515            on_failed_upgrade: DefaultOnFailedUpgrade,
516        })
517    }
518}
519
520fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
521    if let Some(header) = headers.get(&key) {
522        header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
523    } else {
524        false
525    }
526}
527
528fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
529    let header = if let Some(header) = headers.get(&key) {
530        header
531    } else {
532        return false;
533    };
534
535    if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
536        header.to_ascii_lowercase().contains(value)
537    } else {
538        false
539    }
540}
541
542/// A stream of WebSocket messages.
543///
544/// See [the module level documentation](self) for more details.
545#[derive(Debug)]
546pub struct WebSocket {
547    inner: WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
548    protocol: Option<HeaderValue>,
549}
550
551impl WebSocket {
552    /// Receive another message.
553    ///
554    /// Returns `None` if the stream has closed.
555    pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
556        self.next().await
557    }
558
559    /// Send a message.
560    pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
561        self.inner
562            .send(msg.into_tungstenite())
563            .await
564            .map_err(Error::new)
565    }
566
567    /// Return the selected WebSocket subprotocol, if one has been chosen.
568    pub fn protocol(&self) -> Option<&HeaderValue> {
569        self.protocol.as_ref()
570    }
571}
572
573impl FusedStream for WebSocket {
574    /// Returns true if the websocket has been terminated.
575    fn is_terminated(&self) -> bool {
576        self.inner.is_terminated()
577    }
578}
579
580impl Stream for WebSocket {
581    type Item = Result<Message, Error>;
582
583    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
584        loop {
585            match ready!(self.inner.poll_next_unpin(cx)) {
586                Some(Ok(msg)) => {
587                    if let Some(msg) = Message::from_tungstenite(msg) {
588                        return Poll::Ready(Some(Ok(msg)));
589                    }
590                }
591                Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))),
592                None => return Poll::Ready(None),
593            }
594        }
595    }
596}
597
598impl Sink<Message> for WebSocket {
599    type Error = Error;
600
601    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
602        Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
603    }
604
605    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
606        Pin::new(&mut self.inner)
607            .start_send(item.into_tungstenite())
608            .map_err(Error::new)
609    }
610
611    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
612        Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
613    }
614
615    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
616        Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
617    }
618}
619
620/// UTF-8 wrapper for [Bytes].
621///
622/// An [Utf8Bytes] is always guaranteed to contain valid UTF-8.
623#[derive(Debug, Clone, PartialEq, Eq, Default)]
624pub struct Utf8Bytes(ts::Utf8Bytes);
625
626impl Utf8Bytes {
627    /// Creates from a static str.
628    #[inline]
629    #[must_use]
630    pub const fn from_static(str: &'static str) -> Self {
631        Self(ts::Utf8Bytes::from_static(str))
632    }
633
634    /// Returns as a string slice.
635    #[inline]
636    pub fn as_str(&self) -> &str {
637        self.0.as_str()
638    }
639
640    fn into_tungstenite(self) -> ts::Utf8Bytes {
641        self.0
642    }
643}
644
645impl std::ops::Deref for Utf8Bytes {
646    type Target = str;
647
648    /// ```
649    /// /// Example fn that takes a str slice
650    /// fn a(s: &str) {}
651    ///
652    /// let data = axum::extract::ws::Utf8Bytes::from_static("foo123");
653    ///
654    /// // auto-deref as arg
655    /// a(&data);
656    ///
657    /// // deref to str methods
658    /// assert_eq!(data.len(), 6);
659    /// ```
660    #[inline]
661    fn deref(&self) -> &Self::Target {
662        self.as_str()
663    }
664}
665
666impl std::fmt::Display for Utf8Bytes {
667    #[inline]
668    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
669        f.write_str(self.as_str())
670    }
671}
672
673impl TryFrom<Bytes> for Utf8Bytes {
674    type Error = std::str::Utf8Error;
675
676    #[inline]
677    fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
678        Ok(Self(bytes.try_into()?))
679    }
680}
681
682impl TryFrom<Vec<u8>> for Utf8Bytes {
683    type Error = std::str::Utf8Error;
684
685    #[inline]
686    fn try_from(v: Vec<u8>) -> Result<Self, Self::Error> {
687        Ok(Self(v.try_into()?))
688    }
689}
690
691impl From<String> for Utf8Bytes {
692    #[inline]
693    fn from(s: String) -> Self {
694        Self(s.into())
695    }
696}
697
698impl From<&str> for Utf8Bytes {
699    #[inline]
700    fn from(s: &str) -> Self {
701        Self(s.into())
702    }
703}
704
705impl From<&String> for Utf8Bytes {
706    #[inline]
707    fn from(s: &String) -> Self {
708        Self(s.into())
709    }
710}
711
712impl From<Utf8Bytes> for Bytes {
713    #[inline]
714    fn from(Utf8Bytes(bytes): Utf8Bytes) -> Self {
715        bytes.into()
716    }
717}
718
719impl<T> PartialEq<T> for Utf8Bytes
720where
721    for<'a> &'a str: PartialEq<T>,
722{
723    /// ```
724    /// let payload = axum::extract::ws::Utf8Bytes::from_static("foo123");
725    /// assert_eq!(payload, "foo123");
726    /// assert_eq!(payload, "foo123".to_string());
727    /// assert_eq!(payload, &"foo123".to_string());
728    /// assert_eq!(payload, std::borrow::Cow::from("foo123"));
729    /// ```
730    #[inline]
731    fn eq(&self, other: &T) -> bool {
732        self.as_str() == *other
733    }
734}
735
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
///
/// Converted to and from tungstenite's `CloseCode` when close messages are sent or
/// received.
pub type CloseCode = u16;

/// Payload of a [`Message::Close`]: a status code plus a textual reason.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame {
    /// The reason as a code.
    pub code: CloseCode,
    /// The reason as a human-readable text string (may be empty).
    pub reason: Utf8Bytes,
}
747
748/// A WebSocket message.
749//
750// This code comes from https://github.com/snapview/tungstenite-rs/blob/master/src/protocol/message.rs and is under following license:
751// Copyright (c) 2017 Alexey Galakhov
752// Copyright (c) 2016 Jason Housley
753//
754// Permission is hereby granted, free of charge, to any person obtaining a copy
755// of this software and associated documentation files (the "Software"), to deal
756// in the Software without restriction, including without limitation the rights
757// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
758// copies of the Software, and to permit persons to whom the Software is
759// furnished to do so, subject to the following conditions:
760//
761// The above copyright notice and this permission notice shall be included in
762// all copies or substantial portions of the Software.
763//
764// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
765// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
766// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
767// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
768// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
769// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
770// THE SOFTWARE.
771#[derive(Debug, Eq, PartialEq, Clone)]
772pub enum Message {
773    /// A text WebSocket message
774    Text(Utf8Bytes),
775    /// A binary WebSocket message
776    Binary(Bytes),
777    /// A ping message with the specified payload
778    ///
779    /// The payload here must have a length less than 125 bytes.
780    ///
781    /// Ping messages will be automatically responded to by the server, so you do not have to worry
782    /// about dealing with them yourself.
783    Ping(Bytes),
784    /// A pong message with the specified payload
785    ///
786    /// The payload here must have a length less than 125 bytes.
787    ///
788    /// Pong messages will be automatically sent to the client if a ping message is received, so
789    /// you do not have to worry about constructing them yourself unless you want to implement a
790    /// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
791    Pong(Bytes),
792    /// A close message with the optional close frame.
793    ///
794    /// You may "uncleanly" close a WebSocket connection at any time
795    /// by simply dropping the [`WebSocket`].
796    /// However, you may also use the graceful closing protocol, in which
797    /// 1. peer A sends a close frame, and does not send any further messages;
798    /// 2. peer B responds with a close frame, and does not send any further messages;
799    /// 3. peer A processes the remaining messages sent by peer B, before finally
800    /// 4. both peers close the connection.
801    ///
802    /// After sending a close frame,
803    /// you may still read messages,
804    /// but attempts to send another message will error.
805    /// After receiving a close frame,
806    /// axum will automatically respond with a close frame if necessary
807    /// (you do not have to deal with this yourself).
808    /// Since no further messages will be received,
809    /// you may either do nothing
810    /// or explicitly drop the connection.
811    Close(Option<CloseFrame>),
812}
813
814impl Message {
815    fn into_tungstenite(self) -> ts::Message {
816        match self {
817            Self::Text(text) => ts::Message::Text(text.into_tungstenite()),
818            Self::Binary(binary) => ts::Message::Binary(binary),
819            Self::Ping(ping) => ts::Message::Ping(ping),
820            Self::Pong(pong) => ts::Message::Pong(pong),
821            Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
822                code: ts::protocol::frame::coding::CloseCode::from(close.code),
823                reason: close.reason.into_tungstenite(),
824            })),
825            Self::Close(None) => ts::Message::Close(None),
826        }
827    }
828
829    fn from_tungstenite(message: ts::Message) -> Option<Self> {
830        match message {
831            ts::Message::Text(text) => Some(Self::Text(Utf8Bytes(text))),
832            ts::Message::Binary(binary) => Some(Self::Binary(binary)),
833            ts::Message::Ping(ping) => Some(Self::Ping(ping)),
834            ts::Message::Pong(pong) => Some(Self::Pong(pong)),
835            ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
836                code: close.code.into(),
837                reason: Utf8Bytes(close.reason),
838            }))),
839            ts::Message::Close(None) => Some(Self::Close(None)),
840            // we can ignore `Frame` frames as recommended by the tungstenite maintainers
841            // https://github.com/snapview/tungstenite-rs/issues/268
842            ts::Message::Frame(_) => None,
843        }
844    }
845
846    /// Consume the WebSocket and return it as binary data.
847    pub fn into_data(self) -> Bytes {
848        match self {
849            Self::Text(string) => Bytes::from(string),
850            Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
851            Self::Close(None) => Bytes::new(),
852            Self::Close(Some(frame)) => Bytes::from(frame.reason),
853        }
854    }
855
856    /// Attempt to consume the WebSocket message and convert it to a Utf8Bytes.
857    pub fn into_text(self) -> Result<Utf8Bytes, Error> {
858        match self {
859            Self::Text(string) => Ok(string),
860            Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => {
861                Ok(Utf8Bytes::try_from(data).map_err(Error::new)?)
862            }
863            Self::Close(None) => Ok(Utf8Bytes::default()),
864            Self::Close(Some(frame)) => Ok(frame.reason),
865        }
866    }
867
868    /// Attempt to get a &str from the WebSocket message,
869    /// this will try to convert binary data to utf8.
870    pub fn to_text(&self) -> Result<&str, Error> {
871        match *self {
872            Self::Text(ref string) => Ok(string.as_str()),
873            Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
874                Ok(std::str::from_utf8(data).map_err(Error::new)?)
875            }
876            Self::Close(None) => Ok(""),
877            Self::Close(Some(ref frame)) => Ok(&frame.reason),
878        }
879    }
880
881    /// Create a new text WebSocket message from a stringable.
882    pub fn text<S>(string: S) -> Message
883    where
884        S: Into<Utf8Bytes>,
885    {
886        Message::Text(string.into())
887    }
888
889    /// Create a new binary WebSocket message by converting to `Bytes`.
890    pub fn binary<B>(bin: B) -> Message
891    where
892        B: Into<Bytes>,
893    {
894        Message::Binary(bin.into())
895    }
896}
897
898impl From<String> for Message {
899    fn from(string: String) -> Self {
900        Message::Text(string.into())
901    }
902}
903
904impl<'s> From<&'s str> for Message {
905    fn from(string: &'s str) -> Self {
906        Message::Text(string.into())
907    }
908}
909
910impl<'b> From<&'b [u8]> for Message {
911    fn from(data: &'b [u8]) -> Self {
912        Message::Binary(Bytes::copy_from_slice(data))
913    }
914}
915
916impl From<Bytes> for Message {
917    fn from(data: Bytes) -> Self {
918        Message::Binary(data)
919    }
920}
921
922impl From<Vec<u8>> for Message {
923    fn from(data: Vec<u8>) -> Self {
924        Message::Binary(data.into())
925    }
926}
927
928impl From<Message> for Vec<u8> {
929    fn from(msg: Message) -> Self {
930        msg.into_data().to_vec()
931    }
932}
933
934fn sign(key: &[u8]) -> HeaderValue {
935    use base64::engine::Engine as _;
936
937    let mut sha1 = Sha1::default();
938    sha1.update(key);
939    sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
940    let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
941    HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
942}
943
pub mod rejection {
    //! WebSocket specific rejections.
    //!
    //! Each type below describes one way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
    //! extractor can fail. They are all collected in [`WebSocketUpgradeRejection`].

    use axum_core::__composite_rejection as composite_rejection;
    use axum_core::__define_rejection as define_rejection;

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `GET`"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Responds with `405 Method Not Allowed` when a `GET` request was required but the
        /// request used a different method.
        pub struct MethodNotGet;
    }

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `CONNECT`"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Responds with `405 Method Not Allowed` when a `CONNECT` request was required but the
        /// request used a different method. Presumably this applies to HTTP/2 extended CONNECT
        /// upgrades.
        pub struct MethodNotConnect;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "Connection header did not include 'upgrade'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Responds with `400 Bad Request` when the `Connection` header does not include
        /// `upgrade`.
        pub struct InvalidConnectionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Upgrade` header did not include 'websocket'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Responds with `400 Bad Request` when the `Upgrade` header does not include
        /// `websocket`.
        pub struct InvalidUpgradeHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`:protocol` pseudo-header did not include 'websocket'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Responds with `400 Bad Request` when the HTTP/2 `:protocol` pseudo-header does not
        /// include `websocket`.
        pub struct InvalidProtocolPseudoheader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Version` header did not include '13'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Responds with `400 Bad Request` when the `Sec-WebSocket-Version` header does not
        /// include `13`.
        pub struct InvalidWebSocketVersionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Key` header missing"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Responds with `400 Bad Request` when the `Sec-WebSocket-Key` header is missing.
        pub struct WebSocketKeyHeaderMissing;
    }

    define_rejection! {
        #[status = UPGRADE_REQUIRED]
        #[body = "WebSocket request couldn't be upgraded since no upgrade state was present"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// This rejection is returned if the connection cannot be upgraded for example if the
        /// request is HTTP/1.0.
        ///
        /// See [MDN] for more details about connection upgrades.
        ///
        /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Upgrade
        pub struct ConnectionNotUpgradable;
    }

    composite_rejection! {
        /// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
        /// extractor can fail.
        pub enum WebSocketUpgradeRejection {
            MethodNotGet,
            MethodNotConnect,
            InvalidConnectionHeader,
            InvalidUpgradeHeader,
            InvalidProtocolPseudoheader,
            InvalidWebSocketVersionHeader,
            WebSocketKeyHeaderMissing,
            ConnectionNotUpgradable,
        }
    }
}
1030
pub mod close_code {
    //! Constants for [`CloseCode`]s.
    //!
    //! These are status codes from the IANA WebSocket Close Code Number Registry. Most of them
    //! are defined in RFC 6455 §7.4.1.
    //!
    //! [`CloseCode`]: super::CloseCode

    /// Indicates a normal closure, meaning that the purpose for which the connection was
    /// established has been fulfilled.
    pub const NORMAL: u16 = 1000;

    /// Indicates that an endpoint is "going away", such as a server going down or a browser having
    /// navigated away from a page.
    pub const AWAY: u16 = 1001;

    /// Indicates that an endpoint is terminating the connection due to a protocol error.
    pub const PROTOCOL: u16 = 1002;

    /// Indicates that an endpoint is terminating the connection because it has received a type of
    /// data that it cannot accept.
    ///
    /// For example, an endpoint MAY send this if it understands only text data, but receives a binary message.
    pub const UNSUPPORTED: u16 = 1003;

    /// Indicates that no status code was included in a closing frame.
    ///
    /// This is a reserved value: RFC 6455 says an endpoint must not set it as the status code
    /// of a Close frame it sends. It only describes a received frame.
    pub const STATUS: u16 = 1005;

    /// Indicates an abnormal closure, e.g. the connection was closed without a Close frame being
    /// sent or received.
    ///
    /// This is a reserved value: RFC 6455 says an endpoint must not set it as the status code
    /// of a Close frame it sends.
    pub const ABNORMAL: u16 = 1006;

    /// Indicates that an endpoint is terminating the connection because it has received data
    /// within a message that was not consistent with the type of the message.
    ///
    /// For example, an endpoint received non-UTF-8 (RFC 3629) data within a text message.
    pub const INVALID: u16 = 1007;

    /// Indicates that an endpoint is terminating the connection because it has received a message
    /// that violates its policy.
    ///
    /// This is a generic status code that can be returned when there is
    /// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to
    /// hide specific details about the policy.
    pub const POLICY: u16 = 1008;

    /// Indicates that an endpoint is terminating the connection because it has received a message
    /// that is too big for it to process.
    pub const SIZE: u16 = 1009;

    /// Indicates that an endpoint (client) is terminating the connection because the server
    /// did not respond to extension negotiation correctly.
    ///
    /// Specifically, the client expected the server to negotiate one or more extensions,
    /// but the server didn't return them in the response message of the WebSocket handshake.
    /// The list of extensions that are needed should be given as the reason for closing.
    /// Note that this status code is not used by the server,
    /// because it can fail the WebSocket handshake instead.
    pub const EXTENSION: u16 = 1010;

    /// Indicates that a server is terminating the connection because it encountered an unexpected
    /// condition that prevented it from fulfilling the request.
    pub const ERROR: u16 = 1011;

    /// Indicates that the server is restarting.
    pub const RESTART: u16 = 1012;

    /// Indicates that the server is overloaded and the client should either connect to a different
    /// IP (when multiple targets exist), or reconnect to the same IP when a user has performed an
    /// action.
    pub const AGAIN: u16 = 1013;
}
1099
#[cfg(test)]
mod tests {
    use std::future::ready;

    use super::*;
    use crate::{routing::any, test_helpers::spawn_service, Router};
    use http::{Request, Version};
    use http_body_util::BodyExt as _;
    use hyper_util::rt::TokioExecutor;
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;
    use tokio_tungstenite::tungstenite;
    use tower::ServiceExt;

    // An HTTP/1.0 request that otherwise carries valid upgrade headers must be rejected with
    // `ConnectionNotUpgradable`.
    #[crate::test]
    async fn rejects_http_1_0_requests() {
        // Extracting a `Result` lets the handler inspect the rejection itself instead of axum
        // converting it into an error response.
        let svc = any(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
            let rejection = ws.unwrap_err();
            assert!(matches!(
                rejection,
                WebSocketUpgradeRejection::ConnectionNotUpgradable(_)
            ));
            std::future::ready(())
        });

        let req = Request::builder()
            .version(Version::HTTP_10)
            .method(Method::GET)
            .header("upgrade", "websocket")
            .header("connection", "Upgrade")
            .header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==")
            .header("sec-websocket-version", "13")
            .body(Body::empty())
            .unwrap();

        let res = svc.oneshot(req).await.unwrap();

        // The handler responds with `()`, so 200 means it ran (and its assertions held).
        assert_eq!(res.status(), StatusCode::OK);
    }

    // Compile-only check (never called): `on_upgrade` works without a failure callback.
    #[allow(dead_code)]
    fn default_on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", any(handler));
    }

    // Compile-only check (never called): `on_failed_upgrade` can be chained before `on_upgrade`.
    #[allow(dead_code)]
    fn on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_failed_upgrade(|_error: Error| println!("oops!"))
                .on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", any(handler));
    }

    // End-to-end HTTP/1.1 upgrade against a real server, using a tungstenite client.
    #[crate::test]
    async fn integration_test() {
        let addr = spawn_service(echo_app());
        let uri = format!("ws://{addr}/echo").try_into().unwrap();
        let req = tungstenite::client::ClientRequestBuilder::new(uri)
            .with_sub_protocol(TEST_ECHO_APP_REQ_SUBPROTO);
        let (socket, response) = tokio_tungstenite::connect_async(req).await.unwrap();
        test_echo_app(socket, response.headers()).await;
    }

    // End-to-end WebSocket over HTTP/2 via an extended CONNECT request.
    #[crate::test]
    #[cfg(feature = "http2")]
    async fn http2() {
        let addr = spawn_service(echo_app());
        let io = TokioIo::new(TcpStream::connect(addr).await.unwrap());
        let (mut send_request, conn) =
            hyper::client::conn::http2::Builder::new(TokioExecutor::new())
                .handshake(io)
                .await
                .unwrap();

        // Wait a little for the SETTINGS frame to go through…
        // NOTE(review): yielding a fixed number of times is timing-dependent.
        for _ in 0..10 {
            tokio::task::yield_now().await;
        }
        assert!(conn.is_extended_connect_protocol_enabled());
        // Drive the HTTP/2 connection in the background for the rest of the test.
        tokio::spawn(async {
            conn.await.unwrap();
        });

        // `Protocol` sets the `:protocol` pseudo-header of the extended CONNECT request.
        let req = Request::builder()
            .method(Method::CONNECT)
            .extension(hyper::ext::Protocol::from_static("websocket"))
            .uri("/echo")
            .header("sec-websocket-version", "13")
            .header("sec-websocket-protocol", TEST_ECHO_APP_REQ_SUBPROTO)
            .header("Host", "server.example.com")
            .body(Body::empty())
            .unwrap();

        let mut response = send_request.send_request(req).await.unwrap();
        let status = response.status();
        // Over HTTP/2 a successful upgrade is answered with 200. On failure, include the
        // response body in the panic message to ease debugging.
        if status != 200 {
            let body = response.into_body().collect().await.unwrap().to_bytes();
            let body = std::str::from_utf8(&body).unwrap();
            panic!("response status was {status}: {body}");
        }
        let upgraded = hyper::upgrade::on(&mut response).await.unwrap();
        let upgraded = TokioIo::new(upgraded);
        let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
        test_echo_app(socket, response.headers()).await;
    }

    // Router with a single `/echo` route that sends text, binary and close messages back.
    fn echo_app() -> Router {
        async fn handle_socket(mut socket: WebSocket) {
            assert_eq!(socket.protocol().unwrap(), "echo");
            while let Some(Ok(msg)) = socket.recv().await {
                match msg {
                    Message::Text(_) | Message::Binary(_) | Message::Close(_) => {
                        if socket.send(msg).await.is_err() {
                            break;
                        }
                    }
                    Message::Ping(_) | Message::Pong(_) => {
                        // tungstenite will respond to pings automatically
                    }
                }
            }
        }

        Router::new().route(
            "/echo",
            any(|ws: WebSocketUpgrade| {
                // The server supports "echo2" and "echo". Against the clients' "echo3, echo",
                // the only overlap is "echo".
                let ws = ws.protocols(["echo2", "echo"]);
                assert_eq!(ws.selected_protocol().unwrap(), "echo");
                ready(ws.on_upgrade(handle_socket))
            }),
        )
    }

    // Sub-protocol list sent by the test clients in `Sec-WebSocket-Protocol`.
    const TEST_ECHO_APP_REQ_SUBPROTO: &str = "echo3, echo";
    // Client-side assertions shared by the HTTP/1.1 and HTTP/2 tests:
    // - the negotiated sub-protocol is "echo";
    // - a text message is echoed back unchanged;
    // - a ping is answered with a pong carrying the same payload.
    async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(
        mut socket: WebSocketStream<S>,
        headers: &http::HeaderMap,
    ) {
        assert_eq!(headers[http::header::SEC_WEBSOCKET_PROTOCOL], "echo");

        let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar"));
        socket.send(input.clone()).await.unwrap();
        let output = socket.next().await.unwrap().unwrap();
        assert_eq!(input, output);

        socket
            .send(tungstenite::Message::Ping(Bytes::from_static(b"ping")))
            .await
            .unwrap();
        let output = socket.next().await.unwrap().unwrap();
        assert_eq!(
            output,
            tungstenite::Message::Pong(Bytes::from_static(b"ping"))
        );
    }
}