musli_web/axum08.rs
//! The server implementation for [axum].
//!
//! Use [`server()`] to wrap an upgraded axum WebSocket in a server and feed it
//! incoming requests.
//!
//! [axum]: <https://docs.rs/axum>
6
7use core::pin::Pin;
8use core::task::Poll;
9use core::task::{Context, ready};
10
11use bytes::Bytes;
12
13use axum_core05::Error;
14use axum08::extract::ws::{CloseFrame, Message, WebSocket};
15use futures_core03::Stream;
16use futures_sink03::Sink;
17
18use crate::ws::{self, Handler, Server, ServerImpl, SocketImpl};
19
20/// Construct a new axum server with the specified handler.
21///
22/// # Examples
23///
24/// ```
25/// # extern crate axum08 as axum;
26/// use std::error::Error;
27/// use std::pin::pin;
28///
29/// use axum::Router;
30/// use axum::extract::State;
31/// use axum::extract::ws::{WebSocket, WebSocketUpgrade};
32/// use axum::response::Response;
33/// use axum::routing::any;
34/// use tokio::sync::broadcast::Sender;
35/// use tokio::time::{self, Duration};
36///
37/// use musli_web::api::MessageId;
38/// use musli_web::axum08;
39/// use musli_web::ws;
40///
41/// mod api {
42/// use musli::{Decode, Encode};
43/// use musli_web::api;
44///
45/// #[derive(Encode, Decode)]
46/// pub struct HelloRequest<'de> {
47/// pub message: &'de str,
48/// }
49///
50/// #[derive(Encode, Decode)]
51/// pub struct HelloResponse<'de> {
52/// pub message: &'de str,
53/// }
54///
55/// #[derive(Encode, Decode)]
56/// pub struct TickEvent<'de> {
57/// pub message: &'de str,
58/// pub tick: u32,
59/// }
60///
61/// api::define! {
62/// pub type Hello;
63///
64/// impl Endpoint for Hello {
65/// impl<'de> Request for HelloRequest<'de>;
66/// type Response<'de> = HelloResponse<'de>;
67/// }
68///
69/// pub type Tick;
70///
71/// impl Broadcast for Tick {
72/// impl<'de> Event for TickEvent<'de>;
73/// }
74/// }
75/// }
76///
77/// #[derive(Debug, Clone)]
78/// enum Broadcast {
79/// Tick { tick: u32 },
80/// }
81///
82/// struct MyHandler;
83///
84/// impl ws::Handler for MyHandler {
85/// type Id = api::Request;
86/// type Response = Option<()>;
87///
88/// async fn handle(
89/// &mut self,
90/// id: Self::Id,
91/// incoming: &mut ws::Incoming<'_>,
92/// outgoing: &mut ws::Outgoing<'_>,
93/// ) -> Self::Response {
94/// tracing::info!("Handling: {id:?}");
95///
96/// match id {
97/// api::Request::Hello => {
98/// let request = incoming.read::<api::HelloRequest<'_>>()?;
99///
100/// outgoing.write(api::HelloResponse {
101/// message: request.message,
102/// });
103///
104/// Some(())
105/// }
106/// }
107/// }
108/// }
109///
110/// async fn handler(ws: WebSocketUpgrade, State(sender): State<Sender<Broadcast>>) -> Response {
111/// ws.on_upgrade(move |socket: WebSocket| async move {
112/// let mut subscribe = sender.subscribe();
113///
114/// let mut server = pin!(axum08::server(socket, MyHandler));
115///
116/// loop {
117/// tokio::select! {
118/// m = subscribe.recv() => {
119/// let Ok(message) = m else {
120/// continue;
121/// };
122///
123/// let result = match message {
124/// Broadcast::Tick { tick } => {
125/// server.as_mut().broadcast(api::TickEvent { message: "tick", tick })
126/// },
127/// };
128///
129/// if let Err(error) = result {
130/// tracing::error!("Broadcast failed: {error}");
131///
132/// let mut error = error.source();
133///
134/// while let Some(e) = error.take() {
135/// tracing::error!("Caused by: {e}");
136/// error = e.source();
137/// }
138/// }
139/// }
140/// result = server.as_mut().run() => {
141/// if let Err(error) = result {
142/// tracing::error!("Websocket error: {error}");
143///
144/// let mut error = error.source();
145///
146/// while let Some(e) = error.take() {
147/// tracing::error!("Caused by: {e}");
148/// error = e.source();
149/// }
150/// }
151///
152/// break;
153/// }
154/// }
155/// }
156/// })
157/// }
158/// ```
159#[inline]
160pub fn server<H>(socket: WebSocket, handler: H) -> Server<AxumServer, H>
161where
162 H: Handler,
163{
164 Server::new(socket, handler)
165}
166
167impl crate::ws::server_sealed::Sealed for AxumServer {}
168
169/// Marker type used in combination with [`Server`] to indicate that the
170/// implementation uses axum.
171///
172/// See [`server()`] for how this is constructed and used.
173#[non_exhaustive]
174pub enum AxumServer {}
175
176impl ServerImpl for AxumServer {
177 type Error = Error;
178 type Message = Message;
179 type Socket = WebSocket;
180
181 #[inline]
182 fn ping(data: Bytes) -> Self::Message {
183 Message::Ping(data)
184 }
185
186 #[inline]
187 fn pong(data: Bytes) -> Self::Message {
188 Message::Pong(data)
189 }
190
191 #[inline]
192 fn binary(data: &[u8]) -> Self::Message {
193 Message::Binary(Bytes::from(data.to_vec()))
194 }
195
196 #[inline]
197 fn close(code: u16, reason: &str) -> Self::Message {
198 Message::Close(Some(CloseFrame {
199 code,
200 reason: reason.into(),
201 }))
202 }
203}
204
205impl crate::ws::socket_sealed::Sealed for WebSocket {}
206
207impl SocketImpl for WebSocket {
208 type Message = Message;
209 type Error = Error;
210
211 #[inline]
212 #[allow(private_interfaces)]
213 fn poll_next(
214 self: Pin<&mut Self>,
215 ctx: &mut Context<'_>,
216 ) -> Poll<Option<Result<ws::Message, Self::Error>>> {
217 let Some(result) = ready!(Stream::poll_next(self, ctx)) else {
218 return Poll::Ready(None);
219 };
220
221 let message = match result {
222 Ok(message) => message,
223 Err(err) => return Poll::Ready(Some(Err(err))),
224 };
225
226 let message = match message {
227 Message::Text(..) => ws::Message::Text,
228 Message::Binary(data) => ws::Message::Binary(data),
229 Message::Ping(data) => ws::Message::Ping(data),
230 Message::Pong(data) => ws::Message::Pong(data),
231 Message::Close(..) => ws::Message::Close,
232 };
233
234 Poll::Ready(Some(Ok(message)))
235 }
236
237 #[inline]
238 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
239 Sink::poll_ready(self, cx)
240 }
241
242 #[inline]
243 fn start_send(self: Pin<&mut Self>, message: Self::Message) -> Result<(), Self::Error> {
244 Sink::start_send(self, message)
245 }
246
247 #[inline]
248 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
249 Sink::poll_flush(self, cx)
250 }
251}