musli_web/axum08.rs
1//! The server implementation for [axum].
2//!
3//! Use [`server()`] to set up the server and feed it incoming requests.
4//!
5//! [axum]: <https://docs.rs/axum>
6
7use core::pin::Pin;
8use core::task::Poll;
9use core::task::{Context, ready};
10
11use bytes::Bytes;
12
13use axum_core05::Error;
14use axum08::extract::ws::{CloseFrame, Message, WebSocket};
15use futures_core03::Stream;
16use futures_sink03::Sink;
17
18use crate::ws::{self, Handler, Server, ServerImpl, SocketImpl};
19
20/// Construct a new axum server with the specified handler.
21///
22/// # Examples
23///
24/// ```
25/// # extern crate axum08 as axum;
26/// use std::error::Error;
27/// use std::pin::pin;
28///
29/// use axum::Router;
30/// use axum::extract::State;
31/// use axum::extract::ws::{WebSocket, WebSocketUpgrade};
32/// use axum::response::Response;
33/// use axum::routing::any;
34/// use tokio::sync::broadcast::Sender;
35/// use tokio::time::{self, Duration};
36///
37/// use musli_web::api::MessageId;
38/// use musli_web::axum08;
39/// use musli_web::ws;
40///
41/// mod api {
42/// use musli::{Decode, Encode};
43/// use musli_web::api;
44///
45/// #[derive(Encode, Decode)]
46/// pub struct HelloRequest<'de> {
47/// pub message: &'de str,
48/// }
49///
50/// #[derive(Encode, Decode)]
51/// pub struct HelloResponse<'de> {
52/// pub message: &'de str,
53/// }
54///
55/// #[derive(Encode, Decode)]
56/// pub struct TickEvent<'de> {
57/// pub message: &'de str,
58/// pub tick: u32,
59/// }
60///
61/// api::define! {
62/// pub type Hello;
63///
64/// impl Endpoint for Hello {
65/// impl<'de> Request for HelloRequest<'de>;
66/// type Response<'de> = HelloResponse<'de>;
67/// }
68///
69/// pub type Tick;
70///
71/// impl Broadcast for Tick {
72/// impl<'de> Event for TickEvent<'de>;
73/// }
74/// }
75/// }
76///
77/// #[derive(Debug, Clone)]
78/// enum Broadcast {
79/// Tick { tick: u32 },
80/// }
81///
82/// struct MyHandler;
83///
84/// impl ws::Handler for MyHandler {
85/// type Id = api::Request;
86/// type Response = Option<()>;
87///
88/// async fn handle(
89/// &mut self,
90/// id: Self::Id,
91/// incoming: &mut ws::Incoming<'_>,
92/// outgoing: &mut ws::Outgoing<'_>,
93/// ) -> Self::Response {
94/// tracing::info!("Handling: {id:?}");
95///
96/// match id {
97/// api::Request::Hello => {
98/// let request = incoming.read::<api::HelloRequest<'_>>()?;
99///
100/// outgoing.write(api::HelloResponse {
101/// message: request.message,
102/// });
103///
104/// Some(())
105/// }
106/// api::Request::Unknown(id) => {
107/// None
108/// }
109/// }
110/// }
111/// }
112///
113/// async fn handler(ws: WebSocketUpgrade, State(sender): State<Sender<Broadcast>>) -> Response {
114/// ws.on_upgrade(move |socket: WebSocket| async move {
115/// let mut subscribe = sender.subscribe();
116///
117/// let mut server = axum08::server(socket, MyHandler);
118///
119/// loop {
120/// tokio::select! {
121/// m = subscribe.recv() => {
122/// let Ok(message) = m else {
123/// continue;
124/// };
125///
126/// let result = match message {
127/// Broadcast::Tick { tick } => {
128/// server.broadcast(api::TickEvent { message: "tick", tick })
129/// },
130/// };
131///
132/// if let Err(error) = result {
133/// tracing::error!("Broadcast failed: {error}");
134///
135/// let mut error = error.source();
136///
137/// while let Some(e) = error.take() {
138/// tracing::error!("Caused by: {e}");
139/// error = e.source();
140/// }
141/// }
142/// }
143/// result = server.run() => {
144/// if let Err(error) = result {
145/// tracing::error!("Websocket error: {error}");
146///
147/// let mut error = error.source();
148///
149/// while let Some(e) = error.take() {
150/// tracing::error!("Caused by: {e}");
151/// error = e.source();
152/// }
153/// }
154///
155/// break;
156/// }
157/// }
158/// }
159/// })
160/// }
161/// ```
162#[inline]
163pub fn server<H>(socket: WebSocket, handler: H) -> Server<AxumServer, H>
164where
165 H: Handler,
166{
167 Server::new(socket, handler)
168}
169
170impl crate::ws::server_sealed::Sealed for AxumServer {}
171
172/// Marker type used in combination with [`Server`] to indicate that the
173/// implementation uses axum.
174///
175/// See [`server()`] for how this is constructed and used.
176#[non_exhaustive]
177pub enum AxumServer {}
178
179impl ServerImpl for AxumServer {
180 type Error = Error;
181 type Message = Message;
182 type Socket = WebSocket;
183
184 #[inline]
185 fn ping(data: Bytes) -> Self::Message {
186 Message::Ping(data)
187 }
188
189 #[inline]
190 fn pong(data: Bytes) -> Self::Message {
191 Message::Pong(data)
192 }
193
194 #[inline]
195 fn binary(data: &[u8]) -> Self::Message {
196 Message::Binary(Bytes::from(data.to_vec()))
197 }
198
199 #[inline]
200 fn close(code: u16, reason: &str) -> Self::Message {
201 Message::Close(Some(CloseFrame {
202 code,
203 reason: reason.into(),
204 }))
205 }
206}
207
208impl crate::ws::socket_sealed::Sealed for WebSocket {}
209
210impl SocketImpl for WebSocket {
211 type Message = Message;
212 type Error = Error;
213
214 #[inline]
215 #[allow(private_interfaces)]
216 fn poll_next(
217 self: Pin<&mut Self>,
218 ctx: &mut Context<'_>,
219 ) -> Poll<Option<Result<ws::Message, Self::Error>>> {
220 let Some(result) = ready!(Stream::poll_next(self, ctx)) else {
221 return Poll::Ready(None);
222 };
223
224 let message = match result {
225 Ok(message) => message,
226 Err(err) => return Poll::Ready(Some(Err(err))),
227 };
228
229 let message = match message {
230 Message::Text(..) => ws::Message::Text,
231 Message::Binary(data) => ws::Message::Binary(data),
232 Message::Ping(data) => ws::Message::Ping(data),
233 Message::Pong(data) => ws::Message::Pong(data),
234 Message::Close(..) => ws::Message::Close,
235 };
236
237 Poll::Ready(Some(Ok(message)))
238 }
239
240 #[inline]
241 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
242 Sink::poll_ready(self, cx)
243 }
244
245 #[inline]
246 fn start_send(self: Pin<&mut Self>, message: Self::Message) -> Result<(), Self::Error> {
247 Sink::start_send(self, message)
248 }
249
250 #[inline]
251 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
252 Sink::poll_flush(self, cx)
253 }
254}