Skip to main content

musli_web/
axum08.rs

1//! The server implementation for [axum].
2//!
3//! Use [`server()`] to set up the server and feed it incoming requests.
4//!
5//! [axum]: <https://docs.rs/axum>
6
7use core::pin::Pin;
8use core::task::Poll;
9use core::task::{Context, ready};
10
11use bytes::Bytes;
12
13use axum_core05::Error;
14use axum08::extract::ws::{CloseFrame, Message, WebSocket};
15use futures_core03::Stream;
16use futures_sink03::Sink;
17
18use crate::ws::{self, Handler, Server, ServerImpl, SocketImpl};
19
20/// Construct a new axum server with the specified handler.
21///
22/// # Examples
23///
24/// ```
25/// # extern crate axum08 as axum;
26/// use std::error::Error;
27/// use std::pin::pin;
28///
29/// use axum::Router;
30/// use axum::extract::State;
31/// use axum::extract::ws::{WebSocket, WebSocketUpgrade};
32/// use axum::response::Response;
33/// use axum::routing::any;
34/// use tokio::sync::broadcast::Sender;
35/// use tokio::time::{self, Duration};
36///
37/// use musli_web::api::MessageId;
38/// use musli_web::axum08;
39/// use musli_web::ws;
40///
41/// mod api {
42///     use musli::{Decode, Encode};
43///     use musli_web::api;
44///
45///     #[derive(Encode, Decode)]
46///     pub struct HelloRequest<'de> {
47///         pub message: &'de str,
48///     }
49///
50///     #[derive(Encode, Decode)]
51///     pub struct HelloResponse<'de> {
52///         pub message: &'de str,
53///     }
54///
55///     #[derive(Encode, Decode)]
56///     pub struct TickEvent<'de> {
57///         pub message: &'de str,
58///         pub tick: u32,
59///     }
60///
61///     api::define! {
62///         pub type Hello;
63///
64///         impl Endpoint for Hello {
65///             impl<'de> Request for HelloRequest<'de>;
66///             type Response<'de> = HelloResponse<'de>;
67///         }
68///
69///         pub type Tick;
70///
71///         impl Broadcast for Tick {
72///             impl<'de> Event for TickEvent<'de>;
73///         }
74///     }
75/// }
76///
77/// #[derive(Debug, Clone)]
78/// enum Broadcast {
79///     Tick { tick: u32 },
80/// }
81///
82/// struct MyHandler;
83///
84/// impl ws::Handler for MyHandler {
85///     type Id = api::Request;
86///     type Response = Option<()>;
87///
88///     async fn handle(
89///         &mut self,
90///         id: Self::Id,
91///         incoming: &mut ws::Incoming<'_>,
92///         outgoing: &mut ws::Outgoing<'_>,
93///     ) -> Self::Response {
94///         tracing::info!("Handling: {id:?}");
95///
96///         match id {
97///             api::Request::Hello => {
98///                 let request = incoming.read::<api::HelloRequest<'_>>()?;
99///
100///                 outgoing.write(api::HelloResponse {
101///                     message: request.message,
102///                 });
103///
104///                 Some(())
105///             }
106///             api::Request::Unknown(id) => {
107///                 None
108///             }
109///         }
110///     }
111/// }
112///
113/// async fn handler(ws: WebSocketUpgrade, State(sender): State<Sender<Broadcast>>) -> Response {
114///     ws.on_upgrade(move |socket: WebSocket| async move {
115///         let mut subscribe = sender.subscribe();
116///
117///         let mut server = axum08::server(socket, MyHandler);
118///
119///         loop {
120///             tokio::select! {
121///                 m = subscribe.recv() => {
122///                     let Ok(message) = m else {
123///                         continue;
124///                     };
125///
126///                     let result = match message {
127///                         Broadcast::Tick { tick } => {
128///                             server.broadcast(api::TickEvent { message: "tick", tick })
129///                         },
130///                     };
131///
132///                     if let Err(error) = result {
133///                         tracing::error!("Broadcast failed: {error}");
134///
135///                         let mut error = error.source();
136///
137///                         while let Some(e) = error.take() {
138///                             tracing::error!("Caused by: {e}");
139///                             error = e.source();
140///                         }
141///                     }
142///                 }
143///                 result = server.run() => {
144///                     if let Err(error) = result {
145///                         tracing::error!("Websocket error: {error}");
146///
147///                         let mut error = error.source();
148///
149///                         while let Some(e) = error.take() {
150///                             tracing::error!("Caused by: {e}");
151///                             error = e.source();
152///                         }
153///                     }
154///
155///                     break;
156///                 }
157///             }
158///         }
159///     })
160/// }
161/// ```
162#[inline]
163pub fn server<H>(socket: WebSocket, handler: H) -> Server<AxumServer, H>
164where
165    H: Handler,
166{
167    Server::new(socket, handler)
168}
169
170impl crate::ws::server_sealed::Sealed for AxumServer {}
171
172/// Marker type used in combination with [`Server`] to indicate that the
173/// implementation uses axum.
174///
175/// See [`server()`] for how this is constructed and used.
176#[non_exhaustive]
177pub enum AxumServer {}
178
179impl ServerImpl for AxumServer {
180    type Error = Error;
181    type Message = Message;
182    type Socket = WebSocket;
183
184    #[inline]
185    fn ping(data: Bytes) -> Self::Message {
186        Message::Ping(data)
187    }
188
189    #[inline]
190    fn pong(data: Bytes) -> Self::Message {
191        Message::Pong(data)
192    }
193
194    #[inline]
195    fn binary(data: &[u8]) -> Self::Message {
196        Message::Binary(Bytes::from(data.to_vec()))
197    }
198
199    #[inline]
200    fn close(code: u16, reason: &str) -> Self::Message {
201        Message::Close(Some(CloseFrame {
202            code,
203            reason: reason.into(),
204        }))
205    }
206}
207
208impl crate::ws::socket_sealed::Sealed for WebSocket {}
209
210impl SocketImpl for WebSocket {
211    type Message = Message;
212    type Error = Error;
213
214    #[inline]
215    #[allow(private_interfaces)]
216    fn poll_next(
217        self: Pin<&mut Self>,
218        ctx: &mut Context<'_>,
219    ) -> Poll<Option<Result<ws::Message, Self::Error>>> {
220        let Some(result) = ready!(Stream::poll_next(self, ctx)) else {
221            return Poll::Ready(None);
222        };
223
224        let message = match result {
225            Ok(message) => message,
226            Err(err) => return Poll::Ready(Some(Err(err))),
227        };
228
229        let message = match message {
230            Message::Text(..) => ws::Message::Text,
231            Message::Binary(data) => ws::Message::Binary(data),
232            Message::Ping(data) => ws::Message::Ping(data),
233            Message::Pong(data) => ws::Message::Pong(data),
234            Message::Close(..) => ws::Message::Close,
235        };
236
237        Poll::Ready(Some(Ok(message)))
238    }
239
240    #[inline]
241    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
242        Sink::poll_ready(self, cx)
243    }
244
245    #[inline]
246    fn start_send(self: Pin<&mut Self>, message: Self::Message) -> Result<(), Self::Error> {
247        Sink::start_send(self, message)
248    }
249
250    #[inline]
251    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
252        Sink::poll_flush(self, cx)
253    }
254}