musli_web/
axum08.rs

1//! The server implementation for [axum].
2//!
3//! Use [`server()`] to set up the server and feed it incoming requests.
4//!
5//! [axum]: <https://docs.rs/axum>
6
7use core::pin::Pin;
8use core::task::Poll;
9use core::task::{Context, ready};
10
11use bytes::Bytes;
12
13use axum_core05::Error;
14use axum08::extract::ws::{CloseFrame, Message, WebSocket};
15use futures_core03::Stream;
16use futures_sink03::Sink;
17
18use crate::ws::{self, Handler, Server, ServerImpl, SocketImpl};
19
20/// Construct a new axum server with the specified handler.
21///
22/// # Examples
23///
24/// ```
25/// # extern crate axum08 as axum;
26/// use std::error::Error;
27/// use std::pin::pin;
28///
29/// use axum::Router;
30/// use axum::extract::State;
31/// use axum::extract::ws::{WebSocket, WebSocketUpgrade};
32/// use axum::response::Response;
33/// use axum::routing::any;
34/// use tokio::sync::broadcast::Sender;
35/// use tokio::time::{self, Duration};
36///
37/// use musli_web::api::MessageId;
38/// use musli_web::axum08;
39/// use musli_web::ws;
40///
41/// mod api {
42///     use musli::{Decode, Encode};
43///     use musli_web::api;
44///
45///     #[derive(Encode, Decode)]
46///     pub struct HelloRequest<'de> {
47///         pub message: &'de str,
48///     }
49///
50///     #[derive(Encode, Decode)]
51///     pub struct HelloResponse<'de> {
52///         pub message: &'de str,
53///     }
54///
55///     #[derive(Encode, Decode)]
56///     pub struct TickEvent<'de> {
57///         pub message: &'de str,
58///         pub tick: u32,
59///     }
60///
61///     api::define! {
62///         pub type Hello;
63///
64///         impl Endpoint for Hello {
65///             impl<'de> Request for HelloRequest<'de>;
66///             type Response<'de> = HelloResponse<'de>;
67///         }
68///
69///         pub type Tick;
70///
71///         impl Broadcast for Tick {
72///             impl<'de> Event for TickEvent<'de>;
73///         }
74///     }
75/// }
76///
77/// #[derive(Debug, Clone)]
78/// enum Broadcast {
79///     Tick { tick: u32 },
80/// }
81///
82/// struct MyHandler;
83///
84/// impl ws::Handler for MyHandler {
85///     type Id = api::Request;
86///     type Response = Option<()>;
87///
88///     async fn handle(
89///         &mut self,
90///         id: Self::Id,
91///         incoming: &mut ws::Incoming<'_>,
92///         outgoing: &mut ws::Outgoing<'_>,
93///     ) -> Self::Response {
94///         tracing::info!("Handling: {id:?}");
95///
96///         match id {
97///             api::Request::Hello => {
98///                 let request = incoming.read::<api::HelloRequest<'_>>()?;
99///
100///                 outgoing.write(api::HelloResponse {
101///                     message: request.message,
102///                 });
103///
104///                 Some(())
105///             }
106///         }
107///     }
108/// }
109///
110/// async fn handler(ws: WebSocketUpgrade, State(sender): State<Sender<Broadcast>>) -> Response {
111///     ws.on_upgrade(move |socket: WebSocket| async move {
112///         let mut subscribe = sender.subscribe();
113///
114///         let mut server = pin!(axum08::server(socket, MyHandler));
115///
116///         loop {
117///             tokio::select! {
118///                 m = subscribe.recv() => {
119///                     let Ok(message) = m else {
120///                         continue;
121///                     };
122///
123///                     let result = match message {
124///                         Broadcast::Tick { tick } => {
125///                             server.as_mut().broadcast(api::TickEvent { message: "tick", tick })
126///                         },
127///                     };
128///
129///                     if let Err(error) = result {
130///                         tracing::error!("Broadcast failed: {error}");
131///
132///                         let mut error = error.source();
133///
134///                         while let Some(e) = error.take() {
135///                             tracing::error!("Caused by: {e}");
136///                             error = e.source();
137///                         }
138///                     }
139///                 }
140///                 result = server.as_mut().run() => {
141///                     if let Err(error) = result {
142///                         tracing::error!("Websocket error: {error}");
143///
144///                         let mut error = error.source();
145///
146///                         while let Some(e) = error.take() {
147///                             tracing::error!("Caused by: {e}");
148///                             error = e.source();
149///                         }
150///                     }
151///
152///                     break;
153///                 }
154///             }
155///         }
156///     })
157/// }
158/// ```
159#[inline]
160pub fn server<H>(socket: WebSocket, handler: H) -> Server<AxumServer, H>
161where
162    H: Handler,
163{
164    Server::new(socket, handler)
165}
166
167impl crate::ws::server_sealed::Sealed for AxumServer {}
168
/// Uninhabited marker type that selects the axum backend for [`Server`].
///
/// It has no variants and so can never be constructed. It only appears as
/// the backend type parameter of the [`Server`] returned by [`server()`].
#[non_exhaustive]
pub enum AxumServer {}
175
176impl ServerImpl for AxumServer {
177    type Error = Error;
178    type Message = Message;
179    type Socket = WebSocket;
180
181    #[inline]
182    fn ping(data: Bytes) -> Self::Message {
183        Message::Ping(data)
184    }
185
186    #[inline]
187    fn pong(data: Bytes) -> Self::Message {
188        Message::Pong(data)
189    }
190
191    #[inline]
192    fn binary(data: &[u8]) -> Self::Message {
193        Message::Binary(Bytes::from(data.to_vec()))
194    }
195
196    #[inline]
197    fn close(code: u16, reason: &str) -> Self::Message {
198        Message::Close(Some(CloseFrame {
199            code,
200            reason: reason.into(),
201        }))
202    }
203}
204
205impl crate::ws::socket_sealed::Sealed for WebSocket {}
206
207impl SocketImpl for WebSocket {
208    type Message = Message;
209    type Error = Error;
210
211    #[inline]
212    #[allow(private_interfaces)]
213    fn poll_next(
214        self: Pin<&mut Self>,
215        ctx: &mut Context<'_>,
216    ) -> Poll<Option<Result<ws::Message, Self::Error>>> {
217        let Some(result) = ready!(Stream::poll_next(self, ctx)) else {
218            return Poll::Ready(None);
219        };
220
221        let message = match result {
222            Ok(message) => message,
223            Err(err) => return Poll::Ready(Some(Err(err))),
224        };
225
226        let message = match message {
227            Message::Text(..) => ws::Message::Text,
228            Message::Binary(data) => ws::Message::Binary(data),
229            Message::Ping(data) => ws::Message::Ping(data),
230            Message::Pong(data) => ws::Message::Pong(data),
231            Message::Close(..) => ws::Message::Close,
232        };
233
234        Poll::Ready(Some(Ok(message)))
235    }
236
237    #[inline]
238    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
239        Sink::poll_ready(self, cx)
240    }
241
242    #[inline]
243    fn start_send(self: Pin<&mut Self>, message: Self::Message) -> Result<(), Self::Error> {
244        Sink::start_send(self, message)
245    }
246
247    #[inline]
248    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
249        Sink::poll_flush(self, cx)
250    }
251}