1#![warn(
75 clippy::all,
76 clippy::dbg_macro,
77 clippy::todo,
78 clippy::empty_enum,
79 clippy::enum_glob_use,
80 clippy::mem_forget,
81 clippy::unused_self,
82 clippy::filter_map_next,
83 clippy::needless_continue,
84 clippy::needless_borrow,
85 clippy::match_wildcard_for_single_variants,
86 clippy::if_let_mutex,
87 clippy::mismatched_target_os,
88 clippy::await_holding_lock,
89 clippy::match_on_vec_items,
90 clippy::imprecise_flops,
91 clippy::suboptimal_flops,
92 clippy::lossy_float_literal,
93 clippy::rest_pat_in_fully_bound_structs,
94 clippy::fn_params_excessive_bools,
95 clippy::exit,
96 clippy::inefficient_to_string,
97 clippy::linkedlist,
98 clippy::macro_use_imports,
99 clippy::option_option,
100 clippy::verbose_file_reads,
101 clippy::unnested_or_patterns,
102 rust_2018_idioms,
103 future_incompatible,
104 nonstandard_style,
105 missing_debug_implementations,
106 missing_docs
107)]
108#![deny(unreachable_pub, private_interfaces, private_bounds)]
109#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
110#![forbid(unsafe_code)]
111#![cfg_attr(docsrs, feature(doc_cfg))]
112#![cfg_attr(test, allow(clippy::float_cmp))]
113
114use axum::{
115 async_trait,
116 extract::{ws, FromRequestParts},
117 http::request::Parts,
118 response::IntoResponse,
119};
120use futures_util::{Sink, SinkExt, Stream, StreamExt};
121use serde::{de::DeserializeOwned, Serialize};
122use std::{
123 error::Error as StdError,
124 fmt,
125 future::Future,
126 marker::PhantomData,
127 pin::Pin,
128 task::{Context, Poll},
129};
130
// Declares a `pub struct` twice, once for each state of the `json` cargo feature.
//
// `JsonCodec` is only compiled when the `json` feature is enabled, so the
// default type parameter `C = JsonCodec` can only be written in that
// configuration. With the feature disabled the same struct is emitted without
// a default for `C`, and users must name a codec explicitly.
//
// Outer attributes on the struct (including `///` doc comments, which reach the
// macro as `#[doc = ...]` attributes) are copied onto both variants. Fields are
// copied verbatim; the matcher accepts only `name: Type,` pairs, so fields can
// carry neither attributes nor visibility and always end up private.
//
// NOTE(review): this macro is invoked further down in the file, so the
// `unused_macros` allowance looks defensive — confirm whether it is still needed.
#[allow(unused_macros)]
macro_rules! with_and_without_json {
    (
        $(#[$m:meta])*
        pub struct $name:ident<S, R, C = JsonCodec> {
            $(
                $ident:ident : $ty:ty,
            )*
        }
    ) => {
        // `json` feature on: keep the `JsonCodec` default for `C`.
        $(#[$m])*
        #[cfg(feature = "json")]
        pub struct $name<S, R, C = JsonCodec> {
            $(
                $ident : $ty,
            )*
        }

        // `json` feature off: `JsonCodec` does not exist, so `C` has no default.
        $(#[$m])*
        #[cfg(not(feature = "json"))]
        pub struct $name<S, R, C> {
            $(
                $ident : $ty,
            )*
        }
    }
}
158
with_and_without_json! {
    /// Extractor for establishing a typed WebSocket connection.
    ///
    /// Wraps axum's [`ws::WebSocketUpgrade`] and fixes the message types used on
    /// the resulting socket:
    ///
    /// - `S`: the type of messages sent to the client (must be `Serialize` to send).
    /// - `R`: the type of messages received from the client (must be
    ///   `DeserializeOwned` to receive).
    /// - `C`: the [`Codec`] used to encode and decode message payloads. With the
    ///   `json` feature enabled it defaults to `JsonCodec`.
    ///
    /// Call [`WebSocketUpgrade::on_upgrade`] to complete the upgrade.
    pub struct WebSocketUpgrade<S, R, C = JsonCodec> {
        // The untyped axum extractor that performs the actual upgrade.
        upgrade: ws::WebSocketUpgrade,
        // Records `S`, `R` and `C` without storing values of those types; the
        // `fn() -> ...` form keeps the marker `Send`/`Sync` regardless of them.
        _marker: PhantomData<fn() -> (S, R, C)>,
    }
}
174
175#[async_trait]
176impl<S, R, C, B> FromRequestParts<B> for WebSocketUpgrade<S, R, C>
177where
178 B: Send + Sync,
179{
180 type Rejection = <ws::WebSocketUpgrade as FromRequestParts<B>>::Rejection;
181
182 async fn from_request_parts(parts: &mut Parts, state: &B) -> Result<Self, Self::Rejection> {
183 let upgrade = FromRequestParts::from_request_parts(parts, state).await?;
184 Ok(Self {
185 upgrade,
186 _marker: PhantomData,
187 })
188 }
189}
190
191impl<S, R, C> WebSocketUpgrade<S, R, C> {
192 pub fn on_upgrade<F, Fut>(self, callback: F) -> impl IntoResponse
197 where
198 F: FnOnce(WebSocket<S, R, C>) -> Fut + Send + 'static,
199 Fut: Future<Output = ()> + Send + 'static,
200 S: Send,
201 R: Send,
202 {
203 self.upgrade
204 .on_upgrade(|socket| async move {
205 let socket = WebSocket {
206 socket,
207 _marker: PhantomData,
208 };
209 callback(socket).await
210 })
211 .into_response()
212 }
213
214 pub fn map<F>(mut self, f: F) -> Self
218 where
219 F: FnOnce(ws::WebSocketUpgrade) -> ws::WebSocketUpgrade,
220 {
221 self.upgrade = f(self.upgrade);
222 self
223 }
224
225 pub fn into_inner(self) -> ws::WebSocketUpgrade {
227 self.upgrade
228 }
229}
230
231impl<S, R, C> fmt::Debug for WebSocketUpgrade<S, R, C> {
232 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233 f.debug_struct("WebSocketUpgrade")
234 .field("upgrade", &self.upgrade)
235 .finish()
236 }
237}
238
with_and_without_json! {
    /// A WebSocket connection with typed messages.
    ///
    /// Created by [`WebSocketUpgrade::on_upgrade`]. It is a `Stream` of
    /// `Result<Message<R>, Error<_>>` and a `Sink<Message<S>>`, with payloads
    /// converted by the [`Codec`] `C`. With the `json` feature enabled, `C`
    /// defaults to `JsonCodec`.
    pub struct WebSocket<S, R, C = JsonCodec> {
        // The untyped axum socket that carries the raw frames.
        socket: ws::WebSocket,
        // Records `S`, `R` and `C` without storing values of those types; the
        // `fn() -> ...` form keeps the marker `Send`/`Sync` regardless of them.
        _marker: PhantomData<fn() -> (S, R, C)>,
    }
}
247
248impl<S, R, C> WebSocket<S, R, C> {
249 pub async fn recv(&mut self) -> Option<Result<Message<R>, Error<C::Error>>>
256 where
257 R: DeserializeOwned,
258 C: Codec,
259 {
260 self.next().await
261 }
262
263 pub async fn send(&mut self, msg: Message<S>) -> Result<(), Error<C::Error>>
268 where
269 S: Serialize,
270 C: Codec,
271 {
272 SinkExt::send(self, msg).await
273 }
274
275 pub async fn close(self) -> Result<(), Error<C::Error>>
279 where
280 C: Codec,
281 {
282 self.socket.close().await.map_err(Error::Ws)
283 }
284
285 pub fn into_inner(self) -> ws::WebSocket {
287 self.socket
288 }
289}
290
291impl<S, R, C> fmt::Debug for WebSocket<S, R, C> {
292 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293 f.debug_struct("WebSocket")
294 .field("socket", &self.socket)
295 .finish()
296 }
297}
298
299impl<S, R, C> Stream for WebSocket<S, R, C>
300where
301 R: DeserializeOwned,
302 C: Codec,
303{
304 type Item = Result<Message<R>, Error<C::Error>>;
305
306 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
307 let msg = futures_util::ready!(Pin::new(&mut self.socket)
308 .poll_next(cx)
309 .map_err(Error::Ws)?);
310
311 if let Some(msg) = msg {
312 let msg = match msg {
313 ws::Message::Text(msg) => msg.into_bytes(),
314 ws::Message::Binary(bytes) => bytes,
315 ws::Message::Close(frame) => {
316 return Poll::Ready(Some(Ok(Message::Close(frame))));
317 }
318 ws::Message::Ping(buf) => {
319 return Poll::Ready(Some(Ok(Message::Ping(buf))));
320 }
321 ws::Message::Pong(buf) => {
322 return Poll::Ready(Some(Ok(Message::Pong(buf))));
323 }
324 };
325
326 let msg = C::decode(msg).map(Message::Item).map_err(Error::Codec);
327 Poll::Ready(Some(msg))
328 } else {
329 Poll::Ready(None)
330 }
331 }
332}
333
334impl<S, R, C> Sink<Message<S>> for WebSocket<S, R, C>
335where
336 S: Serialize,
337 C: Codec,
338{
339 type Error = Error<C::Error>;
340
341 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
342 Pin::new(&mut self.socket).poll_ready(cx).map_err(Error::Ws)
343 }
344
345 fn start_send(mut self: Pin<&mut Self>, msg: Message<S>) -> Result<(), Self::Error> {
346 let msg = match msg {
347 Message::Item(buf) => ws::Message::Binary(C::encode(buf).map_err(Error::Codec)?),
348 Message::Ping(buf) => ws::Message::Ping(buf),
349 Message::Pong(buf) => ws::Message::Pong(buf),
350 Message::Close(frame) => ws::Message::Close(frame),
351 };
352
353 Pin::new(&mut self.socket)
354 .start_send(msg)
355 .map_err(Error::Ws)
356 }
357
358 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
359 Pin::new(&mut self.socket).poll_flush(cx).map_err(Error::Ws)
360 }
361
362 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
363 Pin::new(&mut self.socket).poll_close(cx).map_err(Error::Ws)
364 }
365}
366
/// Converts message values to and from the bytes carried by WebSocket frames.
///
/// [`WebSocket`] calls [`Codec::encode`] for every outgoing [`Message::Item`]
/// and sends the result as a binary frame. It calls [`Codec::decode`] on the
/// payload of every incoming text or binary frame.
pub trait Codec {
    /// Error returned when encoding or decoding fails.
    type Error;

    /// Serializes `msg` into a byte buffer.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if `msg` cannot be encoded.
    fn encode<S>(msg: S) -> Result<Vec<u8>, Self::Error>
    where
        S: Serialize;

    /// Deserializes a value of type `R` from `buf`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if `buf` is not a valid encoding of `R`.
    fn decode<R>(buf: Vec<u8>) -> Result<R, Self::Error>
    where
        R: DeserializeOwned;
}
385
/// A [`Codec`] that encodes and decodes messages as JSON using `serde_json`.
///
/// Only available with the `json` feature enabled.
#[cfg(feature = "json")]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
#[derive(Debug)]
// `#[non_exhaustive]` stops code outside this crate from constructing the unit
// value; the type is meant to be named as a type parameter, not instantiated.
#[non_exhaustive]
pub struct JsonCodec;
392
#[cfg(feature = "json")]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
impl Codec for JsonCodec {
    type Error = serde_json::Error;

    /// Serializes `msg` to a JSON byte vector with `serde_json::to_vec`.
    fn encode<S>(msg: S) -> Result<Vec<u8>, Self::Error>
    where
        S: Serialize,
    {
        let json = serde_json::to_vec(&msg)?;
        Ok(json)
    }

    /// Parses a value of type `R` from the JSON bytes in `buf`.
    fn decode<R>(buf: Vec<u8>) -> Result<R, Self::Error>
    where
        R: DeserializeOwned,
    {
        serde_json::from_slice::<R>(buf.as_slice())
    }
}
412
/// Errors that can occur when sending or receiving typed messages.
///
/// `E` is the error type of the [`Codec`] in use.
#[derive(Debug)]
pub enum Error<E> {
    /// The underlying axum WebSocket reported an error.
    Ws(axum::Error),
    /// The codec failed to encode an outgoing message or decode an incoming one.
    Codec(E),
}
421
422impl<E> fmt::Display for Error<E>
423where
424 E: fmt::Display,
425{
426 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
427 match self {
428 Error::Ws(inner) => inner.fmt(f),
429 Error::Codec(inner) => inner.fmt(f),
430 }
431 }
432}
433
434impl<E> StdError for Error<E>
435where
436 E: StdError + 'static,
437{
438 fn source(&self) -> Option<&(dyn StdError + 'static)> {
439 match self {
440 Error::Ws(inner) => Some(inner),
441 Error::Codec(inner) => Some(inner),
442 }
443 }
444}
445
/// A message sent to or received from a typed [`WebSocket`].
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message<T> {
    /// A value converted by the socket's [`Codec`].
    ///
    /// Outgoing items are encoded and sent as binary frames. Incoming text and
    /// binary frames are both decoded into this variant.
    Item(T),
    /// A ping frame carrying the given payload.
    Ping(Vec<u8>),
    /// A pong frame carrying the given payload.
    Pong(Vec<u8>),
    /// A close frame, optionally carrying an axum [`ws::CloseFrame`].
    Close(Option<ws::CloseFrame<'static>>),
}