axum_typed_websockets/
lib.rs

1//! [`axum::extract::ws`] with type safe messages.
2//!
3//! # Example
4//!
5//! ```rust
6//! use axum::{
7//!     Router,
8//!     response::IntoResponse,
9//!     routing::get,
10//! };
11//! use axum_typed_websockets::{Message, WebSocket, WebSocketUpgrade};
12//! use serde::{Serialize, Deserialize};
13//! use std::time::Instant;
14//!
15//! // Make a regular axum router
16//! let app = Router::new().route("/ws", get(handler));
17//!
18//! # async {
19//! // Run it!
20//! axum::serve(
21//!     tokio::net::TcpListener::bind("0.0.0.0:3000")
22//!         .await
23//!         .unwrap(),
24//!     app.into_make_service()
25//! )
26//! .await
27//! .unwrap();
28//! # };
29//!
30//! async fn handler(
31//!     // Upgrade the request to a WebSocket connection where the server sends
//!     // messages of type `ServerMsg` and the client sends `ClientMsg`
33//!     ws: WebSocketUpgrade<ServerMsg, ClientMsg>,
34//! ) -> impl IntoResponse {
35//!     ws.on_upgrade(ping_pong_socket)
36//! }
37//!
//! // Send a ping and measure how long it takes to get a pong back
39//! async fn ping_pong_socket(mut socket: WebSocket<ServerMsg, ClientMsg>) {
40//!     let start = Instant::now();
41//!     socket.send(Message::Item(ServerMsg::Ping)).await.ok();
42//!
43//!     if let Some(msg) = socket.recv().await {
44//!         match msg {
45//!             Ok(Message::Item(ClientMsg::Pong)) => {
46//!                 println!("ping: {:?}", start.elapsed());
47//!             },
48//!             Ok(_) => {},
49//!             Err(err) => {
50//!                 eprintln!("got error: {}", err);
51//!             },
52//!         }
53//!     }
54//! }
55//!
56//! #[derive(Debug, Serialize)]
57//! enum ServerMsg {
58//!     Ping,
59//! }
60//!
61//! #[derive(Debug, Deserialize)]
62//! enum ClientMsg {
63//!     Pong,
64//! }
65//! ```
66//!
67//! # Feature flags
68//!
69//! The following features are available:
70//!
//! - `json`: Enables [`JsonCodec`], which encodes messages as JSON using
//!   `serde_json`. Enabled by default.
73
74#![warn(
75    clippy::all,
76    clippy::dbg_macro,
77    clippy::todo,
78    clippy::empty_enum,
79    clippy::enum_glob_use,
80    clippy::mem_forget,
81    clippy::unused_self,
82    clippy::filter_map_next,
83    clippy::needless_continue,
84    clippy::needless_borrow,
85    clippy::match_wildcard_for_single_variants,
86    clippy::if_let_mutex,
87    clippy::mismatched_target_os,
88    clippy::await_holding_lock,
89    clippy::match_on_vec_items,
90    clippy::imprecise_flops,
91    clippy::suboptimal_flops,
92    clippy::lossy_float_literal,
93    clippy::rest_pat_in_fully_bound_structs,
94    clippy::fn_params_excessive_bools,
95    clippy::exit,
96    clippy::inefficient_to_string,
97    clippy::linkedlist,
98    clippy::macro_use_imports,
99    clippy::option_option,
100    clippy::verbose_file_reads,
101    clippy::unnested_or_patterns,
102    rust_2018_idioms,
103    future_incompatible,
104    nonstandard_style,
105    missing_debug_implementations,
106    missing_docs
107)]
108#![deny(unreachable_pub, private_interfaces, private_bounds)]
109#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
110#![forbid(unsafe_code)]
111#![cfg_attr(docsrs, feature(doc_cfg))]
112#![cfg_attr(test, allow(clippy::float_cmp))]
113
114use axum::{
115    async_trait,
116    extract::{ws, FromRequestParts},
117    http::request::Parts,
118    response::IntoResponse,
119};
120use futures_util::{Sink, SinkExt, Stream, StreamExt};
121use serde::{de::DeserializeOwned, Serialize};
122use std::{
123    error::Error as StdError,
124    fmt,
125    future::Future,
126    marker::PhantomData,
127    pin::Pin,
128    task::{Context, Poll},
129};
130
// Emits the given struct twice, behind complementary `json` feature gates.
//
// The two copies differ only in whether the codec parameter `C` has a default.
// `JsonCodec` is mentioned solely in the `json`-gated copy, so the crate still
// compiles when the feature (and therefore `JsonCodec`) is disabled.
#[allow(unused_macros)]
macro_rules! with_and_without_json {
    (
        $(#[$attr:meta])*
        pub struct $struct_name:ident<S, R, C = JsonCodec> {
            $(
                $field:ident : $field_ty:ty,
            )*
        }
    ) => {
        $(#[$attr])*
        #[cfg(not(feature = "json"))]
        pub struct $struct_name<S, R, C> {
            $( $field : $field_ty, )*
        }

        $(#[$attr])*
        #[cfg(feature = "json")]
        pub struct $struct_name<S, R, C = JsonCodec> {
            $( $field : $field_ty, )*
        }
    };
}
158
159with_and_without_json! {
160    /// A version of [`axum::extract::ws::WebSocketUpgrade`] with type safe
161    /// messages.
162    ///
163    /// # Type parameters
164    ///
165    /// - `S` - The message sent from the server to the client.
166    /// - `R` - The message sent from the client to the server.
167    /// - `C` - The [`Codec`] used to encode and decode messages. Defaults to
168    /// [`JsonCodec`] which serializes messages with `serde_json`.
169    pub struct WebSocketUpgrade<S, R, C = JsonCodec> {
170        upgrade: ws::WebSocketUpgrade,
171        _marker: PhantomData<fn() -> (S, R, C)>,
172    }
173}
174
175#[async_trait]
176impl<S, R, C, B> FromRequestParts<B> for WebSocketUpgrade<S, R, C>
177where
178    B: Send + Sync,
179{
180    type Rejection = <ws::WebSocketUpgrade as FromRequestParts<B>>::Rejection;
181
182    async fn from_request_parts(parts: &mut Parts, state: &B) -> Result<Self, Self::Rejection> {
183        let upgrade = FromRequestParts::from_request_parts(parts, state).await?;
184        Ok(Self {
185            upgrade,
186            _marker: PhantomData,
187        })
188    }
189}
190
191impl<S, R, C> WebSocketUpgrade<S, R, C> {
192    /// Finalize upgrading the connection and call the provided callback with
193    /// the stream.
194    ///
195    /// This is analagous to [`axum::extract::ws::WebSocketUpgrade::on_upgrade`].
196    pub fn on_upgrade<F, Fut>(self, callback: F) -> impl IntoResponse
197    where
198        F: FnOnce(WebSocket<S, R, C>) -> Fut + Send + 'static,
199        Fut: Future<Output = ()> + Send + 'static,
200        S: Send,
201        R: Send,
202    {
203        self.upgrade
204            .on_upgrade(|socket| async move {
205                let socket = WebSocket {
206                    socket,
207                    _marker: PhantomData,
208                };
209                callback(socket).await
210            })
211            .into_response()
212    }
213
214    /// Apply a transformation to the inner [`axum::extract::ws::WebSocketUpgrade`].
215    ///
216    /// This can be used to apply configuration.
217    pub fn map<F>(mut self, f: F) -> Self
218    where
219        F: FnOnce(ws::WebSocketUpgrade) -> ws::WebSocketUpgrade,
220    {
221        self.upgrade = f(self.upgrade);
222        self
223    }
224
225    /// Get the inner axum [`axum::extract::ws::WebSocketUpgrade`].
226    pub fn into_inner(self) -> ws::WebSocketUpgrade {
227        self.upgrade
228    }
229}
230
231impl<S, R, C> fmt::Debug for WebSocketUpgrade<S, R, C> {
232    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233        f.debug_struct("WebSocketUpgrade")
234            .field("upgrade", &self.upgrade)
235            .finish()
236    }
237}
238
239with_and_without_json! {
240    /// A version of [`axum::extract::ws::WebSocket`] with type safe
241    /// messages.
242    pub struct WebSocket<S, R, C = JsonCodec> {
243        socket: ws::WebSocket,
244        _marker: PhantomData<fn() -> (S, R, C)>,
245    }
246}
247
248impl<S, R, C> WebSocket<S, R, C> {
249    /// Receive another message.
250    ///
251    /// Returns `None` if the stream stream has closed.
252    ///
253    /// This is analagous to [`axum::extract::ws::WebSocket::recv`] but with a
254    /// statically typed message.
255    pub async fn recv(&mut self) -> Option<Result<Message<R>, Error<C::Error>>>
256    where
257        R: DeserializeOwned,
258        C: Codec,
259    {
260        self.next().await
261    }
262
263    /// Send a message.
264    ///
265    /// This is analagous to [`axum::extract::ws::WebSocket::send`] but with a
266    /// statically typed message.
267    pub async fn send(&mut self, msg: Message<S>) -> Result<(), Error<C::Error>>
268    where
269        S: Serialize,
270        C: Codec,
271    {
272        SinkExt::send(self, msg).await
273    }
274
275    /// Gracefully close this WebSocket.
276    ///
277    /// This is analagous to [`axum::extract::ws::WebSocket::close`].
278    pub async fn close(self) -> Result<(), Error<C::Error>>
279    where
280        C: Codec,
281    {
282        self.socket.close().await.map_err(Error::Ws)
283    }
284
285    /// Get the inner axum [`axum::extract::ws::WebSocket`].
286    pub fn into_inner(self) -> ws::WebSocket {
287        self.socket
288    }
289}
290
291impl<S, R, C> fmt::Debug for WebSocket<S, R, C> {
292    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293        f.debug_struct("WebSocket")
294            .field("socket", &self.socket)
295            .finish()
296    }
297}
298
299impl<S, R, C> Stream for WebSocket<S, R, C>
300where
301    R: DeserializeOwned,
302    C: Codec,
303{
304    type Item = Result<Message<R>, Error<C::Error>>;
305
306    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
307        let msg = futures_util::ready!(Pin::new(&mut self.socket)
308            .poll_next(cx)
309            .map_err(Error::Ws)?);
310
311        if let Some(msg) = msg {
312            let msg = match msg {
313                ws::Message::Text(msg) => msg.into_bytes(),
314                ws::Message::Binary(bytes) => bytes,
315                ws::Message::Close(frame) => {
316                    return Poll::Ready(Some(Ok(Message::Close(frame))));
317                }
318                ws::Message::Ping(buf) => {
319                    return Poll::Ready(Some(Ok(Message::Ping(buf))));
320                }
321                ws::Message::Pong(buf) => {
322                    return Poll::Ready(Some(Ok(Message::Pong(buf))));
323                }
324            };
325
326            let msg = C::decode(msg).map(Message::Item).map_err(Error::Codec);
327            Poll::Ready(Some(msg))
328        } else {
329            Poll::Ready(None)
330        }
331    }
332}
333
334impl<S, R, C> Sink<Message<S>> for WebSocket<S, R, C>
335where
336    S: Serialize,
337    C: Codec,
338{
339    type Error = Error<C::Error>;
340
341    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
342        Pin::new(&mut self.socket).poll_ready(cx).map_err(Error::Ws)
343    }
344
345    fn start_send(mut self: Pin<&mut Self>, msg: Message<S>) -> Result<(), Self::Error> {
346        let msg = match msg {
347            Message::Item(buf) => ws::Message::Binary(C::encode(buf).map_err(Error::Codec)?),
348            Message::Ping(buf) => ws::Message::Ping(buf),
349            Message::Pong(buf) => ws::Message::Pong(buf),
350            Message::Close(frame) => ws::Message::Close(frame),
351        };
352
353        Pin::new(&mut self.socket)
354            .start_send(msg)
355            .map_err(Error::Ws)
356    }
357
358    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
359        Pin::new(&mut self.socket).poll_flush(cx).map_err(Error::Ws)
360    }
361
362    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
363        Pin::new(&mut self.socket).poll_close(cx).map_err(Error::Ws)
364    }
365}
366
367/// Trait for encoding and decoding WebSocket messages.
368///
369/// This allows you to customize how messages are encoded when sent over the
370/// wire.
371pub trait Codec {
372    /// The errors that can happen when using this codec.
373    type Error;
374
375    /// Encode a message.
376    fn encode<S>(msg: S) -> Result<Vec<u8>, Self::Error>
377    where
378        S: Serialize;
379
380    /// Decode a message.
381    fn decode<R>(buf: Vec<u8>) -> Result<R, Self::Error>
382    where
383        R: DeserializeOwned;
384}
385
/// A [`Codec`] that encodes and decodes messages as JSON via `serde_json`.
///
/// Only available with the `json` feature, which is enabled by default.
#[cfg(feature = "json")]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
#[derive(Debug)]
#[non_exhaustive]
pub struct JsonCodec;

#[cfg(feature = "json")]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
impl Codec for JsonCodec {
    type Error = serde_json::Error;

    fn encode<S: Serialize>(msg: S) -> Result<Vec<u8>, Self::Error> {
        serde_json::to_vec(&msg)
    }

    fn decode<R: DeserializeOwned>(buf: Vec<u8>) -> Result<R, Self::Error> {
        serde_json::from_slice(buf.as_slice())
    }
}
412
413/// Errors that can happen when using this library.
414#[derive(Debug)]
415pub enum Error<E> {
416    /// Something went wrong with the WebSocket.
417    Ws(axum::Error),
418    /// Something went wrong with the [`Codec`].
419    Codec(E),
420}
421
422impl<E> fmt::Display for Error<E>
423where
424    E: fmt::Display,
425{
426    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
427        match self {
428            Error::Ws(inner) => inner.fmt(f),
429            Error::Codec(inner) => inner.fmt(f),
430        }
431    }
432}
433
434impl<E> StdError for Error<E>
435where
436    E: StdError + 'static,
437{
438    fn source(&self) -> Option<&(dyn StdError + 'static)> {
439        match self {
440            Error::Ws(inner) => Some(inner),
441            Error::Codec(inner) => Some(inner),
442        }
443    }
444}
445
446/// A WebSocket message contain a value of a known type.
447#[derive(Debug, Eq, PartialEq, Clone)]
448pub enum Message<T> {
449    /// An item of type `T`.
450    Item(T),
451    /// A ping message with the specified payload
452    ///
453    /// The payload here must have a length less than 125 bytes
454    Ping(Vec<u8>),
455    /// A pong message with the specified payload
456    ///
457    /// The payload here must have a length less than 125 bytes
458    Pong(Vec<u8>),
459    /// A close message with the optional close frame.
460    Close(Option<ws::CloseFrame<'static>>),
461}