axum_tungstenite/
lib.rs

1//! WebSocket connections for [axum] directly using [tungstenite].
2//!
3//! # Differences from `axum::extract::ws`
4//!
//! axum already supports WebSockets through [`axum::extract::ws`]. However, the fact that axum uses
6//! tungstenite under the hood is a private implementation detail. Thus axum doesn't directly
7//! expose types from tungstenite, such as [`tungstenite::Error`] and [`tungstenite::Message`].
8//! This allows axum to update to a new major version of tungstenite in a new minor version of
9//! axum, which leads to greater API stability.
10//!
11//! This library works differently as it directly uses the types from tungstenite in its public
12//! API. That makes some things simpler but also means axum-tungstenite will receive a new major
13//! version when tungstenite does.
14//!
15//! # Which should you choose?
16//!
17//! By default you should use `axum::extract::ws` unless you specifically need something from
18//! tungstenite and don't mind keeping up with additional breaking changes.
19//!
20//! # Example
21//!
22//! ```
23//! use axum::{
24//!     routing::get,
25//!     response::IntoResponse,
26//!     Router,
27//! };
28//! use axum_tungstenite::{WebSocketUpgrade, WebSocket};
29//!
30//! let app = Router::new().route("/ws", get(handler));
31//!
32//! async fn handler(ws: WebSocketUpgrade) -> impl IntoResponse {
33//!     ws.on_upgrade(handle_socket)
34//! }
35//!
36//! async fn handle_socket(mut socket: WebSocket) {
37//!     while let Some(msg) = socket.recv().await {
38//!         let msg = if let Ok(msg) = msg {
39//!             msg
40//!         } else {
41//!             // client disconnected
42//!             return;
43//!         };
44//!
45//!         if socket.send(msg).await.is_err() {
46//!             // client disconnected
47//!             return;
48//!         }
49//!     }
50//! }
51//! # async {
52//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
53//! # };
54//! ```
55//!
56//! [axum]: https://crates.io/crates/axum
57//! [tungstenite]: https://crates.io/crates/tungstenite
58//! [`axum::extract::ws`]: https://docs.rs/axum/latest/axum/extract/ws/index.html
59//! [`tungstenite::Error`]: https://docs.rs/tungstenite/latest/tungstenite/error/enum.Error.html
60//! [`tungstenite::Message`]: https://docs.rs/tungstenite/latest/tungstenite/enum.Message.html
61
62#![warn(
63    clippy::all,
64    clippy::dbg_macro,
65    clippy::todo,
66    clippy::empty_enum,
67    clippy::enum_glob_use,
68    clippy::mem_forget,
69    clippy::unused_self,
70    clippy::filter_map_next,
71    clippy::needless_continue,
72    clippy::needless_borrow,
73    clippy::match_wildcard_for_single_variants,
74    clippy::if_let_mutex,
75    clippy::mismatched_target_os,
76    clippy::await_holding_lock,
77    clippy::match_on_vec_items,
78    clippy::imprecise_flops,
79    clippy::suboptimal_flops,
80    clippy::lossy_float_literal,
81    clippy::rest_pat_in_fully_bound_structs,
82    clippy::fn_params_excessive_bools,
83    clippy::exit,
84    clippy::inefficient_to_string,
85    clippy::linkedlist,
86    clippy::macro_use_imports,
87    clippy::option_option,
88    clippy::verbose_file_reads,
89    clippy::unnested_or_patterns,
90    clippy::str_to_string,
91    rust_2018_idioms,
92    future_incompatible,
93    nonstandard_style,
94    missing_debug_implementations,
95    missing_docs
96)]
97#![deny(unreachable_pub, private_in_public)]
98#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
99#![forbid(unsafe_code)]
100#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))]
101#![cfg_attr(test, allow(clippy::float_cmp))]
102
103use self::rejection::*;
104use async_trait::async_trait;
105use axum_core::{
106    extract::FromRequestParts,
107    response::{IntoResponse, Response},
108};
109use bytes::Bytes;
110use futures_util::{
111    sink::{Sink, SinkExt},
112    stream::{Stream, StreamExt},
113};
114use http::{
115    header::{self, HeaderMap, HeaderName, HeaderValue},
116    request::Parts,
117    Method, StatusCode,
118};
119use hyper::upgrade::{OnUpgrade, Upgraded};
120use sha1::{Digest, Sha1};
121use std::{
122    borrow::Cow,
123    future::Future,
124    pin::Pin,
125    task::{Context, Poll},
126};
127use tokio_tungstenite::{
128    tungstenite::protocol::{self, WebSocketConfig},
129    WebSocketStream,
130};
131
132#[doc(no_inline)]
133pub use tokio_tungstenite::tungstenite::error::{
134    CapacityError, Error, ProtocolError, TlsError, UrlError,
135};
136#[doc(no_inline)]
137pub use tokio_tungstenite::tungstenite::Message;
138
/// Extractor for establishing WebSocket connections.
///
/// Extraction checks that the request is a `GET` carrying the WebSocket handshake headers
/// (`Connection`, `Upgrade`, `Sec-WebSocket-Version`, `Sec-WebSocket-Key`). Call
/// [`WebSocketUpgrade::on_upgrade`] to produce the `101 Switching Protocols` response and run a
/// callback on the upgraded socket.
///
/// The type parameter `F` is the callback invoked if the background upgrade fails. It defaults
/// to [`DefaultOnFailedUpdgrade`], which ignores the error; see
/// [`WebSocketUpgrade::on_failed_upgrade`].
///
/// See the [module docs](self) for an example.
#[derive(Debug)]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpdgrade> {
    /// Protocol settings handed to the `WebSocketStream` once the connection is upgraded.
    config: WebSocketConfig,
    /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
    protocol: Option<HeaderValue>,
    /// The client's `Sec-WebSocket-Key`, used to compute the `Sec-WebSocket-Accept` response header.
    sec_websocket_key: HeaderValue,
    /// hyper's future that resolves to the upgraded connection.
    on_upgrade: OnUpgrade,
    /// Callback invoked with the error if awaiting `on_upgrade` fails.
    on_failed_upgrade: F,
    /// The subprotocols offered by the client in the `Sec-WebSocket-Protocol` request header, if any.
    sec_websocket_protocol: Option<HeaderValue>,
}
152
153impl<C> WebSocketUpgrade<C> {
154    /// The target minimum size of the write buffer to reach before writing the data
155    /// to the underlying stream.
156    ///
157    /// The default value is 128 KiB.
158    ///
159    /// If set to `0` each message will be eagerly written to the underlying stream.
160    /// It is often more optimal to allow them to buffer a little, hence the default value.
161    ///
162    /// Note: [`flush`](SinkExt::flush) will always fully write the buffer regardless.
163    pub fn write_buffer_size(mut self, size: usize) -> Self {
164        self.config.write_buffer_size = size;
165        self
166    }
167
168    /// The max size of the write buffer in bytes. Setting this can provide backpressure
169    /// in the case the write buffer is filling up due to write errors.
170    ///
171    /// The default value is unlimited.
172    ///
173    /// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size)
174    /// when writes to the underlying stream are failing. So the **write buffer can not
175    /// fill up if you are not observing write errors even if not flushing**.
176    ///
177    /// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size)
178    /// and probably a little more depending on error handling strategy.
179    pub fn max_write_buffer_size(mut self, max: usize) -> Self {
180        self.config.max_write_buffer_size = max;
181        self
182    }
183
184    /// Set the maximum message size (defaults to 64 megabytes)
185    pub fn max_message_size(mut self, max: usize) -> Self {
186        self.config.max_message_size = Some(max);
187        self
188    }
189
190    /// Set the maximum frame size (defaults to 16 megabytes)
191    pub fn max_frame_size(mut self, max: usize) -> Self {
192        self.config.max_frame_size = Some(max);
193        self
194    }
195
196    /// Allow server to accept unmasked frames (defaults to false)
197    pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
198        self.config.accept_unmasked_frames = accept;
199        self
200    }
201
202    /// Set the known protocols.
203    ///
204    /// If the protocol name specified by `Sec-WebSocket-Protocol` header
205    /// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
206    /// return the protocol name.
207    ///
208    /// The protocols should be listed in decreasing order of preference: if the client offers
209    /// multiple protocols that the server could support, the server will pick the first one in
210    /// this list.
211    pub fn protocols<I>(mut self, protocols: I) -> Self
212    where
213        I: IntoIterator,
214        I::Item: Into<Cow<'static, str>>,
215    {
216        if let Some(req_protocols) = self
217            .sec_websocket_protocol
218            .as_ref()
219            .and_then(|p| p.to_str().ok())
220        {
221            self.protocol = protocols
222                .into_iter()
223                .map(Into::into)
224                .find(|protocol| {
225                    req_protocols
226                        .split(',')
227                        .any(|req_protocol| req_protocol.trim() == protocol)
228                })
229                .map(|protocol| match protocol {
230                    Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
231                    Cow::Borrowed(s) => HeaderValue::from_static(s),
232                });
233        }
234
235        self
236    }
237
238    /// Finalize upgrading the connection and call the provided callback with
239    /// the stream.
240    ///
241    /// When using `WebSocketUpgrade`, the response produced by this method
242    /// should be returned from the handler. See the [module docs](self) for an
243    /// example.
244    pub fn on_upgrade<F, Fut>(self, callback: F) -> Response
245    where
246        F: FnOnce(WebSocket) -> Fut + Send + 'static,
247        Fut: Future<Output = ()> + Send + 'static,
248        C: OnFailedUpdgrade,
249    {
250        let on_upgrade = self.on_upgrade;
251        let config = self.config;
252        let on_failed_upgrade = self.on_failed_upgrade;
253
254        let protocol = self.protocol.clone();
255
256        tokio::spawn(async move {
257            let upgraded = match on_upgrade.await {
258                Ok(upgraded) => upgraded,
259                Err(err) => {
260                    on_failed_upgrade.call(err);
261                    return;
262                }
263            };
264
265            let socket =
266                WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
267                    .await;
268            let socket = WebSocket {
269                inner: socket,
270                protocol,
271            };
272            callback(socket).await;
273        });
274
275        #[allow(clippy::declare_interior_mutable_const)]
276        const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
277        #[allow(clippy::declare_interior_mutable_const)]
278        const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
279
280        let mut headers = HeaderMap::new();
281        headers.insert(header::CONNECTION, UPGRADE);
282        headers.insert(header::UPGRADE, WEBSOCKET);
283        headers.insert(
284            header::SEC_WEBSOCKET_ACCEPT,
285            sign(self.sec_websocket_key.as_bytes()),
286        );
287
288        if let Some(protocol) = self.protocol {
289            headers.insert(header::SEC_WEBSOCKET_PROTOCOL, protocol);
290        }
291
292        (StatusCode::SWITCHING_PROTOCOLS, headers).into_response()
293    }
294
295    /// Provide a callback to call if upgrading the connection fails.
296    ///
297    /// The connection upgrade is performed in a background task. If that fails this callback
298    /// will be called.
299    ///
300    /// By default any errors will be silently ignored.
301    ///
302    /// # Example
303    ///
304    /// ```
305    /// use axum::response::Response;
306    /// use axum_tungstenite::WebSocketUpgrade;
307    ///
308    /// async fn handler(ws: WebSocketUpgrade) -> Response {
309    ///     ws.on_failed_upgrade(|error| {
310    ///         report_error(error);
311    ///     })
312    ///     .on_upgrade(|socket| async { /* ... */ })
313    /// }
314    /// #
315    /// # fn report_error(_: hyper::Error) {}
316    /// ```
317    pub fn on_failed_upgrade<C2>(self, callback: C2) -> WebSocketUpgrade<C2>
318    where
319        C2: OnFailedUpdgrade,
320    {
321        WebSocketUpgrade {
322            config: self.config,
323            protocol: self.protocol,
324            sec_websocket_key: self.sec_websocket_key,
325            on_upgrade: self.on_upgrade,
326            on_failed_upgrade: callback,
327            sec_websocket_protocol: self.sec_websocket_protocol,
328        }
329    }
330}
331
332#[async_trait]
333impl<S> FromRequestParts<S> for WebSocketUpgrade
334where
335    S: Sync,
336{
337    type Rejection = WebSocketUpgradeRejection;
338
339    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
340        if parts.method != Method::GET {
341            return Err(MethodNotGet.into());
342        }
343
344        if !header_contains(parts, header::CONNECTION, "upgrade") {
345            return Err(InvalidConnectionHeader.into());
346        }
347
348        if !header_eq(parts, header::UPGRADE, "websocket") {
349            return Err(InvalidUpgradeHeader.into());
350        }
351
352        if !header_eq(parts, header::SEC_WEBSOCKET_VERSION, "13") {
353            return Err(InvalidWebSocketVersionHeader.into());
354        }
355
356        let sec_websocket_key = if let Some(key) = parts.headers.remove(header::SEC_WEBSOCKET_KEY) {
357            key
358        } else {
359            return Err(WebSocketKeyHeaderMissing.into());
360        };
361
362        let on_upgrade = parts.extensions.remove::<OnUpgrade>().unwrap();
363
364        let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
365
366        Ok(Self {
367            config: Default::default(),
368            protocol: None,
369            sec_websocket_key,
370            on_upgrade,
371            on_failed_upgrade: DefaultOnFailedUpdgrade,
372            sec_websocket_protocol,
373        })
374    }
375}
376
377fn header_eq(req: &Parts, key: HeaderName, value: &'static str) -> bool {
378    if let Some(header) = req.headers.get(&key) {
379        header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
380    } else {
381        false
382    }
383}
384
385fn header_contains(req: &Parts, key: HeaderName, value: &'static str) -> bool {
386    let header = if let Some(header) = req.headers.get(&key) {
387        header
388    } else {
389        return false;
390    };
391
392    if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
393        header.to_ascii_lowercase().contains(value)
394    } else {
395        false
396    }
397}
398
/// A stream of WebSocket messages.
///
/// Handed to the callback passed to [`WebSocketUpgrade::on_upgrade`] once the HTTP connection
/// has been upgraded. Implements [`Stream`] for receiving and [`Sink`] for sending [`Message`]s.
#[derive(Debug)]
pub struct WebSocket {
    /// The tokio-tungstenite stream running over the upgraded hyper connection.
    inner: WebSocketStream<Upgraded>,
    /// The subprotocol selected during the handshake, if any.
    protocol: Option<HeaderValue>,
}
405
406impl WebSocket {
407    /// Consume `self` and get the inner [`tokio_tungstenite::WebSocketStream`].
408    pub fn into_inner(self) -> WebSocketStream<Upgraded> {
409        self.inner
410    }
411
412    /// Receive another message.
413    ///
414    /// Returns `None` if the stream has closed.
415    pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
416        self.next().await
417    }
418
419    /// Send a message.
420    pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
421        self.inner.send(msg).await
422    }
423
424    /// Gracefully close this WebSocket.
425    pub async fn close(mut self) -> Result<(), Error> {
426        self.inner.close(None).await
427    }
428
429    /// Return the selected WebSocket subprotocol, if one has been chosen.
430    pub fn protocol(&self) -> Option<&HeaderValue> {
431        self.protocol.as_ref()
432    }
433}
434
435impl Stream for WebSocket {
436    type Item = Result<Message, Error>;
437
438    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
439        self.inner.poll_next_unpin(cx)
440    }
441}
442
443impl Sink<Message> for WebSocket {
444    type Error = Error;
445
446    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
447        Pin::new(&mut self.inner).poll_ready(cx)
448    }
449
450    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
451        Pin::new(&mut self.inner).start_send(item)
452    }
453
454    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
455        Pin::new(&mut self.inner).poll_flush(cx)
456    }
457
458    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
459        Pin::new(&mut self.inner).poll_close(cx)
460    }
461}
462
463fn sign(key: &[u8]) -> HeaderValue {
464    use base64::engine::Engine as _;
465
466    let mut sha1 = Sha1::default();
467    sha1.update(key);
468    sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
469    let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
470    HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
471}
472
/// What to do when a connection upgrade fails.
///
/// Implemented for any `FnOnce(hyper::Error) + Send + 'static` closure, and by
/// [`DefaultOnFailedUpdgrade`], which ignores the error.
///
/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
///
/// Note: the name is spelled `Updgrade`; renaming it would be a breaking change to the public API.
pub trait OnFailedUpdgrade: Send + 'static {
    /// Call the callback.
    ///
    /// `error` is the error produced while awaiting hyper's `OnUpgrade` future. This is invoked
    /// from the background task spawned by [`WebSocketUpgrade::on_upgrade`].
    fn call(self, error: hyper::Error);
}
480
481impl<F> OnFailedUpdgrade for F
482where
483    F: FnOnce(hyper::Error) + Send + 'static,
484{
485    fn call(self, error: hyper::Error) {
486        self(error)
487    }
488}
489
/// The default `OnFailedUpdgrade` used by `WebSocketUpgrade`.
///
/// It simply ignores the error. The `WebSocketUpgrade` extractor installs it when the request
/// is extracted. Because it is `#[non_exhaustive]`, it cannot be constructed outside this crate.
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultOnFailedUpdgrade;
496
497impl OnFailedUpdgrade for DefaultOnFailedUpdgrade {
498    #[inline]
499    fn call(self, _error: hyper::Error) {}
500}
501
502pub mod rejection {
503    //! WebSocket specific rejections.
504
505    use super::*;
506
507    macro_rules! define_rejection {
508        (
509            #[status = $status:ident]
510            #[body = $body:expr]
511            $(#[$m:meta])*
512            pub struct $name:ident;
513        ) => {
514            $(#[$m])*
515            #[derive(Debug)]
516            #[non_exhaustive]
517            pub struct $name;
518
519            impl IntoResponse for $name {
520                fn into_response(self) -> Response {
521                    (http::StatusCode::$status, $body).into_response()
522                }
523            }
524
525            impl std::fmt::Display for $name {
526                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527                    write!(f, "{}", $body)
528                }
529            }
530
531            impl std::error::Error for $name {}
532        };
533    }
534
535    define_rejection! {
536        #[status = METHOD_NOT_ALLOWED]
537        #[body = "Request method must be `GET`"]
538        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
539        pub struct MethodNotGet;
540    }
541
542    define_rejection! {
543        #[status = BAD_REQUEST]
544        #[body = "Connection header did not include 'upgrade'"]
545        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
546        pub struct InvalidConnectionHeader;
547    }
548
549    define_rejection! {
550        #[status = BAD_REQUEST]
551        #[body = "`Upgrade` header did not include 'websocket'"]
552        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
553        pub struct InvalidUpgradeHeader;
554    }
555
556    define_rejection! {
557        #[status = BAD_REQUEST]
558        #[body = "`Sec-WebSocket-Version` header did not include '13'"]
559        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
560        pub struct InvalidWebSocketVersionHeader;
561    }
562
563    define_rejection! {
564        #[status = BAD_REQUEST]
565        #[body = "`Sec-WebSocket-Key` header missing"]
566        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
567        pub struct WebSocketKeyHeaderMissing;
568    }
569
570    macro_rules! composite_rejection {
571        (
572            $(#[$m:meta])*
573            pub enum $name:ident {
574                $($variant:ident),+
575                $(,)?
576            }
577        ) => {
578            $(#[$m])*
579            #[derive(Debug)]
580            #[non_exhaustive]
581            pub enum $name {
582                $(
583                    #[allow(missing_docs)]
584                    $variant($variant)
585                ),+
586            }
587
588            impl IntoResponse for $name {
589                fn into_response(self) -> Response {
590                    match self {
591                        $(
592                            Self::$variant(inner) => inner.into_response(),
593                        )+
594                    }
595                }
596            }
597
598            $(
599                impl From<$variant> for $name {
600                    fn from(inner: $variant) -> Self {
601                        Self::$variant(inner)
602                    }
603                }
604            )+
605
606            impl std::fmt::Display for $name {
607                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
608                    match self {
609                        $(
610                            Self::$variant(inner) => write!(f, "{}", inner),
611                        )+
612                    }
613                }
614            }
615
616            impl std::error::Error for $name {
617                fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
618                    match self {
619                        $(
620                            Self::$variant(inner) => Some(inner),
621                        )+
622                    }
623                }
624            }
625        };
626    }
627
628    composite_rejection! {
629        /// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
630        ///
631        /// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
632        /// extractor can fail.
633        pub enum WebSocketUpgradeRejection {
634            MethodNotGet,
635            InvalidConnectionHeader,
636            InvalidUpgradeHeader,
637            InvalidWebSocketVersionHeader,
638            WebSocketKeyHeaderMissing,
639        }
640    }
641}